mirror of
https://github.com/csunny/DB-GPT.git
synced 2025-09-17 23:18:20 +00:00
feat(core): AWEL flow 2.0 backend code (#1879)
Co-authored-by: yhjun1026 <460342015@qq.com>
This commit is contained in:
@@ -10,11 +10,17 @@ from sqlalchemy import Column, DateTime, Integer, String, Text, UniqueConstraint
|
||||
|
||||
from dbgpt._private.pydantic import model_to_dict
|
||||
from dbgpt.core.awel.flow.flow_factory import State
|
||||
from dbgpt.core.interface.variables import StorageVariablesProvider
|
||||
from dbgpt.storage.metadata import BaseDao, Model
|
||||
from dbgpt.storage.metadata._base_dao import QUERY_SPEC
|
||||
|
||||
from ..api.schemas import ServeRequest, ServerResponse
|
||||
from ..config import SERVER_APP_TABLE_NAME, ServeConfig
|
||||
from ..api.schemas import (
|
||||
ServeRequest,
|
||||
ServerResponse,
|
||||
VariablesRequest,
|
||||
VariablesResponse,
|
||||
)
|
||||
from ..config import SERVER_APP_TABLE_NAME, SERVER_APP_VARIABLES_TABLE_NAME, ServeConfig
|
||||
|
||||
|
||||
class ServeEntity(Model):
|
||||
@@ -43,6 +49,7 @@ class ServeEntity(Model):
|
||||
editable = Column(
|
||||
Integer, nullable=True, comment="Editable, 0: editable, 1: not editable"
|
||||
)
|
||||
variables = Column(Text, nullable=True, comment="Flow variables, JSON format")
|
||||
user_name = Column(String(128), index=True, nullable=True, comment="User name")
|
||||
sys_code = Column(String(128), index=True, nullable=True, comment="System code")
|
||||
gmt_created = Column(DateTime, default=datetime.now, comment="Record creation time")
|
||||
@@ -74,6 +81,57 @@ class ServeEntity(Model):
|
||||
return editable is None or editable == 0
|
||||
|
||||
|
||||
class VariablesEntity(Model):
|
||||
__tablename__ = SERVER_APP_VARIABLES_TABLE_NAME
|
||||
|
||||
id = Column(Integer, primary_key=True, comment="Auto increment id")
|
||||
key = Column(String(128), index=True, nullable=False, comment="Variable key")
|
||||
name = Column(String(128), index=True, nullable=True, comment="Variable name")
|
||||
label = Column(String(128), nullable=True, comment="Variable label")
|
||||
value = Column(Text, nullable=True, comment="Variable value, JSON format")
|
||||
value_type = Column(
|
||||
String(32),
|
||||
nullable=True,
|
||||
comment="Variable value type(string, int, float, bool)",
|
||||
)
|
||||
category = Column(
|
||||
String(32),
|
||||
default="common",
|
||||
nullable=True,
|
||||
comment="Variable category(common or secret)",
|
||||
)
|
||||
encryption_method = Column(
|
||||
String(32),
|
||||
nullable=True,
|
||||
comment="Variable encryption method(fernet, simple, rsa, aes)",
|
||||
)
|
||||
salt = Column(String(128), nullable=True, comment="Variable salt")
|
||||
scope = Column(
|
||||
String(32),
|
||||
default="global",
|
||||
nullable=True,
|
||||
comment="Variable scope(global,flow,app,agent,datasource,flow_priv,agent_priv, "
|
||||
"etc)",
|
||||
)
|
||||
scope_key = Column(
|
||||
String(256),
|
||||
nullable=True,
|
||||
comment="Variable scope key, default is empty, for scope is 'flow_priv', "
|
||||
"the scope_key is dag id of flow",
|
||||
)
|
||||
enabled = Column(
|
||||
Integer,
|
||||
default=1,
|
||||
nullable=True,
|
||||
comment="Variable enabled, 0: disabled, 1: enabled",
|
||||
)
|
||||
description = Column(Text, nullable=True, comment="Variable description")
|
||||
user_name = Column(String(128), index=True, nullable=True, comment="User name")
|
||||
sys_code = Column(String(128), index=True, nullable=True, comment="System code")
|
||||
gmt_created = Column(DateTime, default=datetime.now, comment="Record creation time")
|
||||
gmt_modified = Column(DateTime, default=datetime.now, comment="Record update time")
|
||||
|
||||
|
||||
class ServeDao(BaseDao[ServeEntity, ServeRequest, ServerResponse]):
|
||||
"""The DAO class for Flow"""
|
||||
|
||||
@@ -98,6 +156,11 @@ class ServeDao(BaseDao[ServeEntity, ServeRequest, ServerResponse]):
|
||||
error_message = request_dict.get("error_message")
|
||||
if error_message:
|
||||
error_message = error_message[:500]
|
||||
|
||||
variables_raw = request_dict.get("variables")
|
||||
variables = (
|
||||
json.dumps(variables_raw, ensure_ascii=False) if variables_raw else None
|
||||
)
|
||||
new_dict = {
|
||||
"uid": request_dict.get("uid"),
|
||||
"dag_id": request_dict.get("dag_id"),
|
||||
@@ -113,6 +176,7 @@ class ServeDao(BaseDao[ServeEntity, ServeRequest, ServerResponse]):
|
||||
"define_type": request_dict.get("define_type"),
|
||||
"editable": ServeEntity.parse_editable(request_dict.get("editable")),
|
||||
"description": request_dict.get("description"),
|
||||
"variables": variables,
|
||||
"user_name": request_dict.get("user_name"),
|
||||
"sys_code": request_dict.get("sys_code"),
|
||||
}
|
||||
@@ -129,6 +193,8 @@ class ServeDao(BaseDao[ServeEntity, ServeRequest, ServerResponse]):
|
||||
REQ: The request
|
||||
"""
|
||||
flow_data = json.loads(entity.flow_data)
|
||||
variables_raw = json.loads(entity.variables) if entity.variables else None
|
||||
variables = ServeRequest.parse_variables(variables_raw)
|
||||
return ServeRequest(
|
||||
uid=entity.uid,
|
||||
dag_id=entity.dag_id,
|
||||
@@ -144,6 +210,7 @@ class ServeDao(BaseDao[ServeEntity, ServeRequest, ServerResponse]):
|
||||
define_type=entity.define_type,
|
||||
editable=ServeEntity.to_bool_editable(entity.editable),
|
||||
description=entity.description,
|
||||
variables=variables,
|
||||
user_name=entity.user_name,
|
||||
sys_code=entity.sys_code,
|
||||
)
|
||||
@@ -160,6 +227,8 @@ class ServeDao(BaseDao[ServeEntity, ServeRequest, ServerResponse]):
|
||||
flow_data = json.loads(entity.flow_data)
|
||||
gmt_created_str = entity.gmt_created.strftime("%Y-%m-%d %H:%M:%S")
|
||||
gmt_modified_str = entity.gmt_modified.strftime("%Y-%m-%d %H:%M:%S")
|
||||
variables_raw = json.loads(entity.variables) if entity.variables else None
|
||||
variables = ServeRequest.parse_variables(variables_raw)
|
||||
return ServerResponse(
|
||||
uid=entity.uid,
|
||||
dag_id=entity.dag_id,
|
||||
@@ -175,6 +244,7 @@ class ServeDao(BaseDao[ServeEntity, ServeRequest, ServerResponse]):
|
||||
version=entity.version,
|
||||
editable=ServeEntity.to_bool_editable(entity.editable),
|
||||
define_type=entity.define_type,
|
||||
variables=variables,
|
||||
user_name=entity.user_name,
|
||||
sys_code=entity.sys_code,
|
||||
gmt_created=gmt_created_str,
|
||||
@@ -215,6 +285,14 @@ class ServeDao(BaseDao[ServeEntity, ServeRequest, ServerResponse]):
|
||||
entry.editable = ServeEntity.parse_editable(update_request.editable)
|
||||
if update_request.define_type:
|
||||
entry.define_type = update_request.define_type
|
||||
|
||||
if update_request.variables:
|
||||
variables_raw = update_request.get_variables_dict()
|
||||
entry.variables = (
|
||||
json.dumps(variables_raw, ensure_ascii=False)
|
||||
if variables_raw
|
||||
else None
|
||||
)
|
||||
if update_request.user_name:
|
||||
entry.user_name = update_request.user_name
|
||||
if update_request.sys_code:
|
||||
@@ -222,3 +300,111 @@ class ServeDao(BaseDao[ServeEntity, ServeRequest, ServerResponse]):
|
||||
session.merge(entry)
|
||||
session.commit()
|
||||
return self.get_one(query_request)
|
||||
|
||||
|
||||
class VariablesDao(BaseDao[VariablesEntity, VariablesRequest, VariablesResponse]):
|
||||
"""The DAO class for Variables"""
|
||||
|
||||
def __init__(self, serve_config: ServeConfig):
|
||||
super().__init__()
|
||||
self._serve_config = serve_config
|
||||
|
||||
def from_request(
|
||||
self, request: Union[VariablesRequest, Dict[str, Any]]
|
||||
) -> VariablesEntity:
|
||||
"""Convert the request to an entity
|
||||
|
||||
Args:
|
||||
request (Union[VariablesRequest, Dict[str, Any]]): The request
|
||||
|
||||
Returns:
|
||||
T: The entity
|
||||
"""
|
||||
request_dict = (
|
||||
model_to_dict(request) if isinstance(request, VariablesRequest) else request
|
||||
)
|
||||
value = StorageVariablesProvider.serialize_value(request_dict.get("value"))
|
||||
enabled = 1 if request_dict.get("enabled", True) else 0
|
||||
new_dict = {
|
||||
"key": request_dict.get("key"),
|
||||
"name": request_dict.get("name"),
|
||||
"label": request_dict.get("label"),
|
||||
"value": value,
|
||||
"value_type": request_dict.get("value_type"),
|
||||
"category": request_dict.get("category"),
|
||||
"encryption_method": request_dict.get("encryption_method"),
|
||||
"salt": request_dict.get("salt"),
|
||||
"scope": request_dict.get("scope"),
|
||||
"scope_key": request_dict.get("scope_key"),
|
||||
"enabled": enabled,
|
||||
"user_name": request_dict.get("user_name"),
|
||||
"sys_code": request_dict.get("sys_code"),
|
||||
"description": request_dict.get("description"),
|
||||
}
|
||||
entity = VariablesEntity(**new_dict)
|
||||
return entity
|
||||
|
||||
def to_request(self, entity: VariablesEntity) -> VariablesRequest:
|
||||
"""Convert the entity to a request
|
||||
|
||||
Args:
|
||||
entity (T): The entity
|
||||
|
||||
Returns:
|
||||
REQ: The request
|
||||
"""
|
||||
value = StorageVariablesProvider.deserialize_value(entity.value)
|
||||
if entity.category == "secret":
|
||||
value = "******"
|
||||
enabled = entity.enabled == 1
|
||||
return VariablesRequest(
|
||||
key=entity.key,
|
||||
name=entity.name,
|
||||
label=entity.label,
|
||||
value=value,
|
||||
value_type=entity.value_type,
|
||||
category=entity.category,
|
||||
encryption_method=entity.encryption_method,
|
||||
salt=entity.salt,
|
||||
scope=entity.scope,
|
||||
scope_key=entity.scope_key,
|
||||
enabled=enabled,
|
||||
user_name=entity.user_name,
|
||||
sys_code=entity.sys_code,
|
||||
description=entity.description,
|
||||
)
|
||||
|
||||
def to_response(self, entity: VariablesEntity) -> VariablesResponse:
|
||||
"""Convert the entity to a response
|
||||
|
||||
Args:
|
||||
entity (T): The entity
|
||||
|
||||
Returns:
|
||||
RES: The response
|
||||
"""
|
||||
value = StorageVariablesProvider.deserialize_value(entity.value)
|
||||
if entity.category == "secret":
|
||||
value = "******"
|
||||
gmt_created_str = entity.gmt_created.strftime("%Y-%m-%d %H:%M:%S")
|
||||
gmt_modified_str = entity.gmt_modified.strftime("%Y-%m-%d %H:%M:%S")
|
||||
enabled = entity.enabled == 1
|
||||
return VariablesResponse(
|
||||
id=entity.id,
|
||||
key=entity.key,
|
||||
name=entity.name,
|
||||
label=entity.label,
|
||||
value=value,
|
||||
value_type=entity.value_type,
|
||||
category=entity.category,
|
||||
encryption_method=entity.encryption_method,
|
||||
salt=entity.salt,
|
||||
scope=entity.scope,
|
||||
scope_key=entity.scope_key,
|
||||
enabled=enabled,
|
||||
user_name=entity.user_name,
|
||||
sys_code=entity.sys_code,
|
||||
gmt_created=gmt_created_str,
|
||||
gmt_modified=gmt_modified_str,
|
||||
description=entity.description,
|
||||
)
|
||||
|
71
dbgpt/serve/flow/models/variables_adapter.py
Normal file
71
dbgpt/serve/flow/models/variables_adapter.py
Normal file
@@ -0,0 +1,71 @@
|
||||
from typing import Type
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from dbgpt.core.interface.storage import StorageItemAdapter
|
||||
from dbgpt.core.interface.variables import StorageVariables, VariablesIdentifier
|
||||
|
||||
from .models import VariablesEntity
|
||||
|
||||
|
||||
class VariablesAdapter(StorageItemAdapter[StorageVariables, VariablesEntity]):
|
||||
"""Variables adapter.
|
||||
|
||||
Convert between storage format and database model.
|
||||
"""
|
||||
|
||||
def to_storage_format(self, item: StorageVariables) -> VariablesEntity:
|
||||
"""Convert to storage format."""
|
||||
return VariablesEntity(
|
||||
key=item.key,
|
||||
name=item.name,
|
||||
label=item.label,
|
||||
value=item.value,
|
||||
value_type=item.value_type,
|
||||
category=item.category,
|
||||
encryption_method=item.encryption_method,
|
||||
salt=item.salt,
|
||||
scope=item.scope,
|
||||
scope_key=item.scope_key,
|
||||
sys_code=item.sys_code,
|
||||
user_name=item.user_name,
|
||||
description=item.description,
|
||||
)
|
||||
|
||||
def from_storage_format(self, model: VariablesEntity) -> StorageVariables:
|
||||
"""Convert from storage format."""
|
||||
return StorageVariables(
|
||||
key=model.key,
|
||||
name=model.name,
|
||||
label=model.label,
|
||||
value=model.value,
|
||||
value_type=model.value_type,
|
||||
category=model.category,
|
||||
encryption_method=model.encryption_method,
|
||||
salt=model.salt,
|
||||
scope=model.scope,
|
||||
scope_key=model.scope_key,
|
||||
sys_code=model.sys_code,
|
||||
user_name=model.user_name,
|
||||
description=model.description,
|
||||
)
|
||||
|
||||
def get_query_for_identifier(
|
||||
self,
|
||||
storage_format: Type[VariablesEntity],
|
||||
resource_id: VariablesIdentifier,
|
||||
**kwargs,
|
||||
):
|
||||
"""Get query for identifier."""
|
||||
session: Session = kwargs.get("session")
|
||||
if session is None:
|
||||
raise Exception("session is None")
|
||||
query_obj = session.query(VariablesEntity)
|
||||
for key, value in resource_id.to_dict().items():
|
||||
if value is None:
|
||||
continue
|
||||
query_obj = query_obj.filter(getattr(VariablesEntity, key) == value)
|
||||
|
||||
# enabled must be True
|
||||
query_obj = query_obj.filter(VariablesEntity.enabled == 1)
|
||||
return query_obj
|
Reference in New Issue
Block a user