mirror of
https://github.com/csunny/DB-GPT.git
synced 2025-09-03 01:54:44 +00:00
feat(agent): Release agent SDK (#1396)
This commit is contained in:
@@ -32,6 +32,7 @@ class DAGManager(BaseComponent):
|
||||
self.dag_loader = LocalFileDAGLoader(dag_dirs)
|
||||
self.system_app = system_app
|
||||
self.dag_map: Dict[str, DAG] = {}
|
||||
self.dag_alias_map: Dict[str, str] = {}
|
||||
self._trigger_manager: Optional["DefaultTriggerManager"] = None
|
||||
|
||||
def init_app(self, system_app: SystemApp):
|
||||
@@ -58,12 +59,14 @@ class DAGManager(BaseComponent):
|
||||
"""Execute after the application starts."""
|
||||
self.load_dags()
|
||||
|
||||
def register_dag(self, dag: DAG):
|
||||
def register_dag(self, dag: DAG, alias_name: Optional[str] = None):
|
||||
"""Register a DAG."""
|
||||
dag_id = dag.dag_id
|
||||
if dag_id in self.dag_map:
|
||||
raise ValueError(f"Register DAG error, DAG ID {dag_id} has already exist")
|
||||
self.dag_map[dag_id] = dag
|
||||
if alias_name:
|
||||
self.dag_alias_map[alias_name] = dag_id
|
||||
|
||||
if self._trigger_manager:
|
||||
for trigger in dag.trigger_nodes:
|
||||
@@ -77,7 +80,22 @@ class DAGManager(BaseComponent):
|
||||
if dag_id not in self.dag_map:
|
||||
raise ValueError(f"Unregister DAG error, DAG ID {dag_id} does not exist")
|
||||
dag = self.dag_map[dag_id]
|
||||
# Clear the alias map
|
||||
for alias_name, _dag_id in self.dag_alias_map.items():
|
||||
if _dag_id == dag_id:
|
||||
del self.dag_alias_map[alias_name]
|
||||
|
||||
if self._trigger_manager:
|
||||
for trigger in dag.trigger_nodes:
|
||||
self._trigger_manager.unregister_trigger(trigger, self.system_app)
|
||||
del self.dag_map[dag_id]
|
||||
|
||||
def get_dag(
|
||||
self, dag_id: Optional[str] = None, alias_name: Optional[str] = None
|
||||
) -> Optional[DAG]:
|
||||
"""Get a DAG by dag_id or alias_name."""
|
||||
if dag_id and dag_id in self.dag_map:
|
||||
return self.dag_map[dag_id]
|
||||
if alias_name in self.dag_alias_map:
|
||||
return self.dag_map.get(self.dag_alias_map[alias_name])
|
||||
return None
|
||||
|
@@ -46,6 +46,13 @@ def _get_type_name(type_: Type[Any]) -> str:
|
||||
return type_name
|
||||
|
||||
|
||||
def _register_alias_types(type_: Type[Any], alias_ids: Optional[List[str]] = None):
|
||||
if alias_ids:
|
||||
for alias_id in alias_ids:
|
||||
if alias_id not in _TYPE_REGISTRY:
|
||||
_TYPE_REGISTRY[alias_id] = type_
|
||||
|
||||
|
||||
def _get_type_cls(type_name: str) -> Type[Any]:
|
||||
"""Get the type class by the type name.
|
||||
|
||||
@@ -58,9 +65,15 @@ def _get_type_cls(type_name: str) -> Type[Any]:
|
||||
Raises:
|
||||
ValueError: If the type is not registered.
|
||||
"""
|
||||
if type_name not in _TYPE_REGISTRY:
|
||||
from .compat import get_new_class_name
|
||||
|
||||
new_cls = get_new_class_name(type_name)
|
||||
if type_name in _TYPE_REGISTRY:
|
||||
return _TYPE_REGISTRY[type_name]
|
||||
elif new_cls and new_cls in _TYPE_REGISTRY:
|
||||
return _TYPE_REGISTRY[new_cls]
|
||||
else:
|
||||
raise ValueError(f"Type {type_name} not registered.")
|
||||
return _TYPE_REGISTRY[type_name]
|
||||
|
||||
|
||||
# Register the basic types.
|
||||
@@ -734,6 +747,12 @@ class ResourceMetadata(BaseMetadata, TypeMetadata):
|
||||
values["id"] = values["flow_type"] + "_" + values["type_cls"]
|
||||
return values
|
||||
|
||||
def new_alias(self, alias: Optional[List[str]] = None) -> List[str]:
|
||||
"""Get the new alias id."""
|
||||
if not alias:
|
||||
return []
|
||||
return [f"{self.flow_type}_{a}" for a in alias]
|
||||
|
||||
|
||||
def register_resource(
|
||||
label: str,
|
||||
@@ -742,6 +761,7 @@ def register_resource(
|
||||
parameters: Optional[List[Parameter]] = None,
|
||||
description: Optional[str] = None,
|
||||
resource_type: ResourceType = ResourceType.INSTANCE,
|
||||
alias: Optional[List[str]] = None,
|
||||
**kwargs,
|
||||
):
|
||||
"""Register the resource.
|
||||
@@ -755,6 +775,9 @@ def register_resource(
|
||||
description (Optional[str], optional): The description of the resource.
|
||||
Defaults to None.
|
||||
resource_type (ResourceType, optional): The type of the resource.
|
||||
alias (Optional[List[str]], optional): The alias of the resource. Defaults to
|
||||
None. For compatibility, we can use the alias to register the resource.
|
||||
|
||||
"""
|
||||
if resource_type == ResourceType.CLASS and parameters:
|
||||
raise ValueError("Class resource can't have parameters.")
|
||||
@@ -784,7 +807,9 @@ def register_resource(
|
||||
resource_type=resource_type,
|
||||
**kwargs,
|
||||
)
|
||||
_register_resource(cls, resource_metadata)
|
||||
alias_ids = resource_metadata.new_alias(alias)
|
||||
_register_alias_types(cls, alias_ids)
|
||||
_register_resource(cls, resource_metadata, alias_ids)
|
||||
# Attach the metadata to the class
|
||||
cls._resource_metadata = resource_metadata
|
||||
return cls
|
||||
@@ -949,11 +974,19 @@ class FlowRegistry:
|
||||
self._registry: Dict[str, _RegistryItem] = {}
|
||||
|
||||
def register_flow(
|
||||
self, view_cls: Type, metadata: Union[ViewMetadata, ResourceMetadata]
|
||||
self,
|
||||
view_cls: Type,
|
||||
metadata: Union[ViewMetadata, ResourceMetadata],
|
||||
alias_ids: Optional[List[str]] = None,
|
||||
):
|
||||
"""Register the operator."""
|
||||
key = metadata.id
|
||||
self._registry[key] = _RegistryItem(key=key, cls=view_cls, metadata=metadata)
|
||||
if alias_ids:
|
||||
for alias_id in alias_ids:
|
||||
self._registry[alias_id] = _RegistryItem(
|
||||
key=alias_id, cls=view_cls, metadata=metadata
|
||||
)
|
||||
|
||||
def get_registry_item(self, key: str) -> Optional[_RegistryItem]:
|
||||
"""Get the registry item by the key."""
|
||||
@@ -998,6 +1031,10 @@ def _get_resource_class(type_key: str) -> _RegistryItem:
|
||||
return item
|
||||
|
||||
|
||||
def _register_resource(cls: Type, resource_metadata: ResourceMetadata):
|
||||
def _register_resource(
|
||||
cls: Type,
|
||||
resource_metadata: ResourceMetadata,
|
||||
alias_ids: Optional[List[str]] = None,
|
||||
):
|
||||
"""Register the operator."""
|
||||
_OPERATOR_REGISTRY.register_flow(cls, resource_metadata)
|
||||
_OPERATOR_REGISTRY.register_flow(cls, resource_metadata, alias_ids)
|
||||
|
40
dbgpt/core/awel/flow/compat.py
Normal file
40
dbgpt/core/awel/flow/compat.py
Normal file
@@ -0,0 +1,40 @@
|
||||
"""Compatibility mapping for flow classes."""
|
||||
|
||||
from typing import Dict, Optional
|
||||
|
||||
_COMPAT_FLOW_MAPPING: Dict[str, str] = {}
|
||||
|
||||
|
||||
_OLD_AGENT_RESOURCE_MODULE = "dbgpt.serve.agent.team.layout.agent_operator_resource"
|
||||
_NEW_AGENT_RESOURCE_MODULE = "dbgpt.agent.plan.awel.agent_operator_resource"
|
||||
|
||||
|
||||
def _register(
|
||||
old_module: str, new_module: str, old_name: str, new_name: Optional[str] = None
|
||||
):
|
||||
if not new_name:
|
||||
new_name = old_name
|
||||
_COMPAT_FLOW_MAPPING[f"{old_module}.{old_name}"] = f"{new_module}.{new_name}"
|
||||
|
||||
|
||||
def get_new_class_name(old_class_name: str) -> Optional[str]:
|
||||
"""Get the new class name for the old class name."""
|
||||
new_cls_name = _COMPAT_FLOW_MAPPING.get(old_class_name, None)
|
||||
return new_cls_name
|
||||
|
||||
|
||||
_register(
|
||||
_OLD_AGENT_RESOURCE_MODULE,
|
||||
_NEW_AGENT_RESOURCE_MODULE,
|
||||
"AwelAgentResource",
|
||||
"AWELAgentResource",
|
||||
)
|
||||
_register(
|
||||
_OLD_AGENT_RESOURCE_MODULE,
|
||||
_NEW_AGENT_RESOURCE_MODULE,
|
||||
"AwelAgentConfig",
|
||||
"AWELAgentConfig",
|
||||
)
|
||||
_register(
|
||||
_OLD_AGENT_RESOURCE_MODULE, _NEW_AGENT_RESOURCE_MODULE, "AwelAgent", "AWELAgent"
|
||||
)
|
@@ -17,6 +17,7 @@ from .base import (
|
||||
_get_operator_class,
|
||||
_get_resource_class,
|
||||
)
|
||||
from .compat import get_new_class_name
|
||||
from .exceptions import (
|
||||
FlowClassMetadataException,
|
||||
FlowDAGMetadataException,
|
||||
@@ -607,9 +608,26 @@ class FlowFactory:
|
||||
f"{metadata_cls}"
|
||||
)
|
||||
except ImportError as e:
|
||||
raise FlowClassMetadataException(
|
||||
f"Import {node_data.type_cls} failed: {e}"
|
||||
)
|
||||
raise_error = True
|
||||
new_type_cls: Optional[str] = None
|
||||
try:
|
||||
new_type_cls = get_new_class_name(node_data.type_cls)
|
||||
if new_type_cls:
|
||||
metadata_cls = import_from_string(new_type_cls)
|
||||
logger.info(
|
||||
f"Import {new_type_cls} successfully, metadata_cls is : "
|
||||
f"{metadata_cls}"
|
||||
)
|
||||
raise_error = False
|
||||
except ImportError as ex:
|
||||
raise FlowClassMetadataException(
|
||||
f"Import {node_data.type_cls} with new type {new_type_cls} "
|
||||
f"failed: {ex}"
|
||||
)
|
||||
if raise_error:
|
||||
raise FlowClassMetadataException(
|
||||
f"Import {node_data.type_cls} failed: {e}"
|
||||
)
|
||||
|
||||
|
||||
def _topological_sort(
|
||||
|
Reference in New Issue
Block a user