feat(core): AWEL flow 2.0 backend code (#1879)

Co-authored-by: yhjun1026 <460342015@qq.com>
This commit is contained in:
Fangyin Cheng
2024-08-23 14:57:54 +08:00
committed by GitHub
parent 3a32344380
commit 9502251c08
67 changed files with 8289 additions and 190 deletions

View File

@@ -10,11 +10,17 @@ from sqlalchemy import Column, DateTime, Integer, String, Text, UniqueConstraint
from dbgpt._private.pydantic import model_to_dict
from dbgpt.core.awel.flow.flow_factory import State
from dbgpt.core.interface.variables import StorageVariablesProvider
from dbgpt.storage.metadata import BaseDao, Model
from dbgpt.storage.metadata._base_dao import QUERY_SPEC
from ..api.schemas import ServeRequest, ServerResponse
from ..config import SERVER_APP_TABLE_NAME, ServeConfig
from ..api.schemas import (
ServeRequest,
ServerResponse,
VariablesRequest,
VariablesResponse,
)
from ..config import SERVER_APP_TABLE_NAME, SERVER_APP_VARIABLES_TABLE_NAME, ServeConfig
class ServeEntity(Model):
@@ -43,6 +49,7 @@ class ServeEntity(Model):
editable = Column(
Integer, nullable=True, comment="Editable, 0: editable, 1: not editable"
)
variables = Column(Text, nullable=True, comment="Flow variables, JSON format")
user_name = Column(String(128), index=True, nullable=True, comment="User name")
sys_code = Column(String(128), index=True, nullable=True, comment="System code")
gmt_created = Column(DateTime, default=datetime.now, comment="Record creation time")
@@ -74,6 +81,57 @@ class ServeEntity(Model):
return editable is None or editable == 0
class VariablesEntity(Model):
__tablename__ = SERVER_APP_VARIABLES_TABLE_NAME
id = Column(Integer, primary_key=True, comment="Auto increment id")
key = Column(String(128), index=True, nullable=False, comment="Variable key")
name = Column(String(128), index=True, nullable=True, comment="Variable name")
label = Column(String(128), nullable=True, comment="Variable label")
value = Column(Text, nullable=True, comment="Variable value, JSON format")
value_type = Column(
String(32),
nullable=True,
comment="Variable value type(string, int, float, bool)",
)
category = Column(
String(32),
default="common",
nullable=True,
comment="Variable category(common or secret)",
)
encryption_method = Column(
String(32),
nullable=True,
comment="Variable encryption method(fernet, simple, rsa, aes)",
)
salt = Column(String(128), nullable=True, comment="Variable salt")
scope = Column(
String(32),
default="global",
nullable=True,
comment="Variable scope(global,flow,app,agent,datasource,flow_priv,agent_priv, "
"etc)",
)
scope_key = Column(
String(256),
nullable=True,
comment="Variable scope key, default is empty, for scope is 'flow_priv', "
"the scope_key is dag id of flow",
)
enabled = Column(
Integer,
default=1,
nullable=True,
comment="Variable enabled, 0: disabled, 1: enabled",
)
description = Column(Text, nullable=True, comment="Variable description")
user_name = Column(String(128), index=True, nullable=True, comment="User name")
sys_code = Column(String(128), index=True, nullable=True, comment="System code")
gmt_created = Column(DateTime, default=datetime.now, comment="Record creation time")
gmt_modified = Column(DateTime, default=datetime.now, comment="Record update time")
class ServeDao(BaseDao[ServeEntity, ServeRequest, ServerResponse]):
"""The DAO class for Flow"""
@@ -98,6 +156,11 @@ class ServeDao(BaseDao[ServeEntity, ServeRequest, ServerResponse]):
error_message = request_dict.get("error_message")
if error_message:
error_message = error_message[:500]
variables_raw = request_dict.get("variables")
variables = (
json.dumps(variables_raw, ensure_ascii=False) if variables_raw else None
)
new_dict = {
"uid": request_dict.get("uid"),
"dag_id": request_dict.get("dag_id"),
@@ -113,6 +176,7 @@ class ServeDao(BaseDao[ServeEntity, ServeRequest, ServerResponse]):
"define_type": request_dict.get("define_type"),
"editable": ServeEntity.parse_editable(request_dict.get("editable")),
"description": request_dict.get("description"),
"variables": variables,
"user_name": request_dict.get("user_name"),
"sys_code": request_dict.get("sys_code"),
}
@@ -129,6 +193,8 @@ class ServeDao(BaseDao[ServeEntity, ServeRequest, ServerResponse]):
REQ: The request
"""
flow_data = json.loads(entity.flow_data)
variables_raw = json.loads(entity.variables) if entity.variables else None
variables = ServeRequest.parse_variables(variables_raw)
return ServeRequest(
uid=entity.uid,
dag_id=entity.dag_id,
@@ -144,6 +210,7 @@ class ServeDao(BaseDao[ServeEntity, ServeRequest, ServerResponse]):
define_type=entity.define_type,
editable=ServeEntity.to_bool_editable(entity.editable),
description=entity.description,
variables=variables,
user_name=entity.user_name,
sys_code=entity.sys_code,
)
@@ -160,6 +227,8 @@ class ServeDao(BaseDao[ServeEntity, ServeRequest, ServerResponse]):
flow_data = json.loads(entity.flow_data)
gmt_created_str = entity.gmt_created.strftime("%Y-%m-%d %H:%M:%S")
gmt_modified_str = entity.gmt_modified.strftime("%Y-%m-%d %H:%M:%S")
variables_raw = json.loads(entity.variables) if entity.variables else None
variables = ServeRequest.parse_variables(variables_raw)
return ServerResponse(
uid=entity.uid,
dag_id=entity.dag_id,
@@ -175,6 +244,7 @@ class ServeDao(BaseDao[ServeEntity, ServeRequest, ServerResponse]):
version=entity.version,
editable=ServeEntity.to_bool_editable(entity.editable),
define_type=entity.define_type,
variables=variables,
user_name=entity.user_name,
sys_code=entity.sys_code,
gmt_created=gmt_created_str,
@@ -215,6 +285,14 @@ class ServeDao(BaseDao[ServeEntity, ServeRequest, ServerResponse]):
entry.editable = ServeEntity.parse_editable(update_request.editable)
if update_request.define_type:
entry.define_type = update_request.define_type
if update_request.variables:
variables_raw = update_request.get_variables_dict()
entry.variables = (
json.dumps(variables_raw, ensure_ascii=False)
if variables_raw
else None
)
if update_request.user_name:
entry.user_name = update_request.user_name
if update_request.sys_code:
@@ -222,3 +300,111 @@ class ServeDao(BaseDao[ServeEntity, ServeRequest, ServerResponse]):
session.merge(entry)
session.commit()
return self.get_one(query_request)
class VariablesDao(BaseDao[VariablesEntity, VariablesRequest, VariablesResponse]):
"""The DAO class for Variables"""
def __init__(self, serve_config: ServeConfig):
super().__init__()
self._serve_config = serve_config
def from_request(
self, request: Union[VariablesRequest, Dict[str, Any]]
) -> VariablesEntity:
"""Convert the request to an entity
Args:
request (Union[VariablesRequest, Dict[str, Any]]): The request
Returns:
T: The entity
"""
request_dict = (
model_to_dict(request) if isinstance(request, VariablesRequest) else request
)
value = StorageVariablesProvider.serialize_value(request_dict.get("value"))
enabled = 1 if request_dict.get("enabled", True) else 0
new_dict = {
"key": request_dict.get("key"),
"name": request_dict.get("name"),
"label": request_dict.get("label"),
"value": value,
"value_type": request_dict.get("value_type"),
"category": request_dict.get("category"),
"encryption_method": request_dict.get("encryption_method"),
"salt": request_dict.get("salt"),
"scope": request_dict.get("scope"),
"scope_key": request_dict.get("scope_key"),
"enabled": enabled,
"user_name": request_dict.get("user_name"),
"sys_code": request_dict.get("sys_code"),
"description": request_dict.get("description"),
}
entity = VariablesEntity(**new_dict)
return entity
def to_request(self, entity: VariablesEntity) -> VariablesRequest:
"""Convert the entity to a request
Args:
entity (T): The entity
Returns:
REQ: The request
"""
value = StorageVariablesProvider.deserialize_value(entity.value)
if entity.category == "secret":
value = "******"
enabled = entity.enabled == 1
return VariablesRequest(
key=entity.key,
name=entity.name,
label=entity.label,
value=value,
value_type=entity.value_type,
category=entity.category,
encryption_method=entity.encryption_method,
salt=entity.salt,
scope=entity.scope,
scope_key=entity.scope_key,
enabled=enabled,
user_name=entity.user_name,
sys_code=entity.sys_code,
description=entity.description,
)
def to_response(self, entity: VariablesEntity) -> VariablesResponse:
"""Convert the entity to a response
Args:
entity (T): The entity
Returns:
RES: The response
"""
value = StorageVariablesProvider.deserialize_value(entity.value)
if entity.category == "secret":
value = "******"
gmt_created_str = entity.gmt_created.strftime("%Y-%m-%d %H:%M:%S")
gmt_modified_str = entity.gmt_modified.strftime("%Y-%m-%d %H:%M:%S")
enabled = entity.enabled == 1
return VariablesResponse(
id=entity.id,
key=entity.key,
name=entity.name,
label=entity.label,
value=value,
value_type=entity.value_type,
category=entity.category,
encryption_method=entity.encryption_method,
salt=entity.salt,
scope=entity.scope,
scope_key=entity.scope_key,
enabled=enabled,
user_name=entity.user_name,
sys_code=entity.sys_code,
gmt_created=gmt_created_str,
gmt_modified=gmt_modified_str,
description=entity.description,
)

View File

@@ -0,0 +1,71 @@
from typing import Type
from sqlalchemy.orm import Session
from dbgpt.core.interface.storage import StorageItemAdapter
from dbgpt.core.interface.variables import StorageVariables, VariablesIdentifier
from .models import VariablesEntity
class VariablesAdapter(StorageItemAdapter[StorageVariables, VariablesEntity]):
"""Variables adapter.
Convert between storage format and database model.
"""
def to_storage_format(self, item: StorageVariables) -> VariablesEntity:
"""Convert to storage format."""
return VariablesEntity(
key=item.key,
name=item.name,
label=item.label,
value=item.value,
value_type=item.value_type,
category=item.category,
encryption_method=item.encryption_method,
salt=item.salt,
scope=item.scope,
scope_key=item.scope_key,
sys_code=item.sys_code,
user_name=item.user_name,
description=item.description,
)
def from_storage_format(self, model: VariablesEntity) -> StorageVariables:
"""Convert from storage format."""
return StorageVariables(
key=model.key,
name=model.name,
label=model.label,
value=model.value,
value_type=model.value_type,
category=model.category,
encryption_method=model.encryption_method,
salt=model.salt,
scope=model.scope,
scope_key=model.scope_key,
sys_code=model.sys_code,
user_name=model.user_name,
description=model.description,
)
def get_query_for_identifier(
self,
storage_format: Type[VariablesEntity],
resource_id: VariablesIdentifier,
**kwargs,
):
"""Get query for identifier."""
session: Session = kwargs.get("session")
if session is None:
raise Exception("session is None")
query_obj = session.query(VariablesEntity)
for key, value in resource_id.to_dict().items():
if value is None:
continue
query_obj = query_obj.filter(getattr(VariablesEntity, key) == value)
# enabled must be True
query_obj = query_obj.filter(VariablesEntity.enabled == 1)
return query_obj