DB-GPT/dbgpt/serve/flow/models/models.py
Fangyin Cheng 9502251c08
feat(core): AWEL flow 2.0 backend code (#1879)
Co-authored-by: yhjun1026 <460342015@qq.com>
2024-08-23 14:57:54 +08:00

411 lines
16 KiB
Python

"""This is an auto-generated model file
You can define your own models and DAOs here
"""
import json
from datetime import datetime
from typing import Any, Dict, Union
from sqlalchemy import Column, DateTime, Integer, String, Text, UniqueConstraint
from dbgpt._private.pydantic import model_to_dict
from dbgpt.core.awel.flow.flow_factory import State
from dbgpt.core.interface.variables import StorageVariablesProvider
from dbgpt.storage.metadata import BaseDao, Model
from dbgpt.storage.metadata._base_dao import QUERY_SPEC
from ..api.schemas import (
ServeRequest,
ServerResponse,
VariablesRequest,
VariablesResponse,
)
from ..config import SERVER_APP_TABLE_NAME, SERVER_APP_VARIABLES_TABLE_NAME, ServeConfig
class ServeEntity(Model):
    """ORM entity for a stored AWEL flow.

    Each row is identified by a unique ``uid``; the flow definition
    (``flow_data``) and its variables (``variables``) are stored as JSON text.
    """

    __tablename__ = SERVER_APP_TABLE_NAME
    __table_args__ = (UniqueConstraint("uid", name="uk_uid"),)

    id = Column(Integer, primary_key=True, comment="Auto increment id")
    uid = Column(String(128), index=True, nullable=False, comment="Unique id")
    dag_id = Column(String(128), index=True, nullable=True, comment="DAG id")
    label = Column(String(128), nullable=True, comment="Flow label")
    name = Column(String(128), index=True, nullable=True, comment="Flow name")
    flow_category = Column(String(64), nullable=True, comment="Flow category")
    flow_data = Column(Text, nullable=True, comment="Flow data, JSON format")
    description = Column(String(512), nullable=True, comment="Flow description")
    state = Column(String(32), nullable=True, comment="Flow state")
    error_message = Column(String(512), nullable=True, comment="Error message")
    source = Column(String(64), nullable=True, comment="Flow source")
    source_url = Column(String(512), nullable=True, comment="Flow source url")
    version = Column(String(32), nullable=True, comment="Flow version")
    define_type = Column(
        String(32),
        default="json",
        nullable=True,
        comment="Flow define type(json or python)",
    )
    # Inverted encoding: 0 means editable, 1 means not editable
    # (see parse_editable / to_bool_editable).
    editable = Column(
        Integer, nullable=True, comment="Editable, 0: editable, 1: not editable"
    )
    variables = Column(Text, nullable=True, comment="Flow variables, JSON format")
    user_name = Column(String(128), index=True, nullable=True, comment="User name")
    sys_code = Column(String(128), index=True, nullable=True, comment="System code")
    gmt_created = Column(DateTime, default=datetime.now, comment="Record creation time")
    # ``onupdate`` makes SQLAlchemy refresh this column on every ORM UPDATE of
    # the row; with ``default`` alone it would keep the creation time forever.
    gmt_modified = Column(
        DateTime,
        default=datetime.now,
        onupdate=datetime.now,
        comment="Record update time",
    )

    def __repr__(self):
        return (
            f"ServeEntity(id={self.id}, uid={self.uid}, dag_id={self.dag_id}, name={self.name}, "
            f"flow_data={self.flow_data}, user_name={self.user_name}, "
            f"sys_code={self.sys_code}, gmt_created={self.gmt_created}, "
            f"gmt_modified={self.gmt_modified})"
        )

    @classmethod
    def parse_editable(cls, editable: Any) -> int:
        """Convert an ``editable`` value to its column encoding.

        Args:
            editable: ``None``, a bool (True means editable) or an int
                (0 maps to editable, any other value to not editable).

        Returns:
            int: 0 for editable, 1 for not editable. ``None`` maps to 0.

        Raises:
            ValueError: If ``editable`` is of any other type.
        """
        if editable is None:
            return 0
        # bool must be checked before int because bool is a subclass of int.
        if isinstance(editable, bool):
            return 0 if editable else 1
        if isinstance(editable, int):
            return 0 if editable == 0 else 1
        raise ValueError(f"Invalid editable: {editable}")

    @classmethod
    def to_bool_editable(cls, editable: Union[int, None]) -> bool:
        """Convert the column encoding of ``editable`` back to a bool.

        Args:
            editable: The stored column value (may be NULL/None).

        Returns:
            bool: True when the stored value is 0 or None (editable).
        """
        return editable is None or editable == 0
class VariablesEntity(Model):
    """ORM entity for a stored variable.

    The value is stored as text in ``value``; ``category``, ``scope`` and
    ``enabled`` fall back to their column defaults when not supplied.
    """

    __tablename__ = SERVER_APP_VARIABLES_TABLE_NAME

    id = Column(Integer, primary_key=True, comment="Auto increment id")
    key = Column(String(128), index=True, nullable=False, comment="Variable key")
    name = Column(String(128), index=True, nullable=True, comment="Variable name")
    label = Column(String(128), nullable=True, comment="Variable label")
    value = Column(Text, nullable=True, comment="Variable value, JSON format")
    value_type = Column(
        String(32),
        nullable=True,
        comment="Variable value type(string, int, float, bool)",
    )
    category = Column(
        String(32),
        default="common",
        nullable=True,
        comment="Variable category(common or secret)",
    )
    encryption_method = Column(
        String(32),
        nullable=True,
        comment="Variable encryption method(fernet, simple, rsa, aes)",
    )
    salt = Column(String(128), nullable=True, comment="Variable salt")
    scope = Column(
        String(32),
        default="global",
        nullable=True,
        comment="Variable scope(global,flow,app,agent,datasource,flow_priv,agent_priv, "
        "etc)",
    )
    scope_key = Column(
        String(256),
        nullable=True,
        comment="Variable scope key, default is empty, for scope is 'flow_priv', "
        "the scope_key is dag id of flow",
    )
    enabled = Column(
        Integer,
        default=1,
        nullable=True,
        comment="Variable enabled, 0: disabled, 1: enabled",
    )
    description = Column(Text, nullable=True, comment="Variable description")
    user_name = Column(String(128), index=True, nullable=True, comment="User name")
    sys_code = Column(String(128), index=True, nullable=True, comment="System code")
    gmt_created = Column(DateTime, default=datetime.now, comment="Record creation time")
    # ``onupdate`` makes SQLAlchemy refresh this column on every ORM UPDATE of
    # the row; with ``default`` alone it would keep the creation time forever.
    gmt_modified = Column(
        DateTime,
        default=datetime.now,
        onupdate=datetime.now,
        comment="Record update time",
    )
class ServeDao(BaseDao[ServeEntity, ServeRequest, ServerResponse]):
    """The DAO class for Flow.

    Converts between :class:`ServeEntity` rows and the ``ServeRequest`` /
    ``ServerResponse`` schemas, and implements partial updates of a flow.
    """

    # String columns that ``update`` copies verbatim when the request value is
    # truthy (None and empty strings leave the stored value unchanged).
    _SIMPLE_UPDATE_FIELDS = (
        "label",
        "name",
        "flow_category",
        "description",
        "source",
        "source_url",
        "version",
        "define_type",
        "user_name",
        "sys_code",
    )

    def __init__(self, serve_config: ServeConfig):
        super().__init__()
        self._serve_config = serve_config

    @staticmethod
    def _format_time(value: Any) -> Union[str, None]:
        """Format a datetime as ``%Y-%m-%d %H:%M:%S``; None stays None."""
        return value.strftime("%Y-%m-%d %H:%M:%S") if value is not None else None

    @staticmethod
    def _load_json(raw: Union[str, None]) -> Any:
        """Decode a JSON text column; None or an empty string yields None."""
        return json.loads(raw) if raw else None

    def from_request(self, request: Union[ServeRequest, Dict[str, Any]]) -> ServeEntity:
        """Convert the request to an entity.

        Args:
            request (Union[ServeRequest, Dict[str, Any]]): The request

        Returns:
            ServeEntity: A new entity built from the request (not added to any
            session).
        """
        request_dict = (
            model_to_dict(request) if isinstance(request, ServeRequest) else request
        )
        flow_data = json.dumps(request_dict.get("flow_data"), ensure_ascii=False)
        # ``dict.get(key, default)`` only applies the default when the key is
        # absent; a key explicitly set to None must fall back as well.
        state = request_dict.get("state") or State.INITIALIZING.value
        # Persist the plain string value (as ``update`` does with
        # ``state.value``) even when an enum member was supplied.
        state = getattr(state, "value", state)
        error_message = request_dict.get("error_message")
        if error_message:
            # Keep the first 500 characters; the column is String(512).
            error_message = error_message[:500]
        variables_raw = request_dict.get("variables")
        variables = (
            json.dumps(variables_raw, ensure_ascii=False) if variables_raw else None
        )
        return ServeEntity(
            uid=request_dict.get("uid"),
            dag_id=request_dict.get("dag_id"),
            label=request_dict.get("label"),
            name=request_dict.get("name"),
            flow_category=request_dict.get("flow_category"),
            flow_data=flow_data,
            state=state,
            error_message=error_message,
            source=request_dict.get("source"),
            source_url=request_dict.get("source_url"),
            version=request_dict.get("version"),
            define_type=request_dict.get("define_type"),
            editable=ServeEntity.parse_editable(request_dict.get("editable")),
            description=request_dict.get("description"),
            variables=variables,
            user_name=request_dict.get("user_name"),
            sys_code=request_dict.get("sys_code"),
        )

    def to_request(self, entity: ServeEntity) -> ServeRequest:
        """Convert the entity to a request.

        Args:
            entity (ServeEntity): The entity

        Returns:
            ServeRequest: The request; NULL JSON columns become None.
        """
        flow_data = self._load_json(entity.flow_data)
        variables = ServeRequest.parse_variables(self._load_json(entity.variables))
        return ServeRequest(
            uid=entity.uid,
            dag_id=entity.dag_id,
            label=entity.label,
            name=entity.name,
            flow_category=entity.flow_category,
            flow_data=flow_data,
            state=State.value_of(entity.state),
            error_message=entity.error_message,
            source=entity.source,
            source_url=entity.source_url,
            version=entity.version,
            define_type=entity.define_type,
            editable=ServeEntity.to_bool_editable(entity.editable),
            description=entity.description,
            variables=variables,
            user_name=entity.user_name,
            sys_code=entity.sys_code,
        )

    def to_response(self, entity: ServeEntity) -> ServerResponse:
        """Convert the entity to a response.

        Args:
            entity (ServeEntity): The entity

        Returns:
            ServerResponse: The response, with timestamps formatted as
            ``%Y-%m-%d %H:%M:%S`` (None when the column is NULL).
        """
        flow_data = self._load_json(entity.flow_data)
        variables = ServeRequest.parse_variables(self._load_json(entity.variables))
        return ServerResponse(
            uid=entity.uid,
            dag_id=entity.dag_id,
            label=entity.label,
            name=entity.name,
            flow_category=entity.flow_category,
            flow_data=flow_data,
            description=entity.description,
            state=State.value_of(entity.state),
            error_message=entity.error_message,
            source=entity.source,
            source_url=entity.source_url,
            version=entity.version,
            editable=ServeEntity.to_bool_editable(entity.editable),
            define_type=entity.define_type,
            variables=variables,
            user_name=entity.user_name,
            sys_code=entity.sys_code,
            gmt_created=self._format_time(entity.gmt_created),
            gmt_modified=self._format_time(entity.gmt_modified),
        )

    def update(
        self, query_request: QUERY_SPEC, update_request: ServeRequest
    ) -> ServerResponse:
        """Partially update the first flow matching ``query_request``.

        Truthy fields of ``update_request`` overwrite the stored values.
        ``error_message`` is applied whenever it is not None, and ``editable``
        is always applied.

        Args:
            query_request (QUERY_SPEC): Selects the flow to update.
            update_request (ServeRequest): The new values.

        Returns:
            ServerResponse: The flow as re-read after the commit.

        Raises:
            Exception: If no flow matches ``query_request``.
        """
        with self.session(commit=False) as session:
            query = self._create_query_object(session, query_request)
            entry: ServeEntity = query.first()
            if entry is None:
                raise Exception("Invalid request")
            for field_name in self._SIMPLE_UPDATE_FIELDS:
                value = getattr(update_request, field_name)
                if value:
                    setattr(entry, field_name, value)
            if update_request.flow_data:
                entry.flow_data = json.dumps(
                    model_to_dict(update_request.flow_data), ensure_ascii=False
                )
            if update_request.state:
                entry.state = update_request.state.value
            if update_request.error_message is not None:
                # Keep the first 500 characters; the column is String(512).
                entry.error_message = update_request.error_message[:500]
            # NOTE(review): applied unconditionally, so a request whose
            # ``editable`` is None marks the flow editable (parse_editable maps
            # None to 0). Confirm this is intended for partial updates.
            entry.editable = ServeEntity.parse_editable(update_request.editable)
            if update_request.variables:
                variables_raw = update_request.get_variables_dict()
                entry.variables = (
                    json.dumps(variables_raw, ensure_ascii=False)
                    if variables_raw
                    else None
                )
            # Record the modification time of this update explicitly.
            entry.gmt_modified = datetime.now()
            session.merge(entry)
            session.commit()
        return self.get_one(query_request)
class VariablesDao(BaseDao[VariablesEntity, VariablesRequest, VariablesResponse]):
    """The DAO class for Variables.

    Values are stored with ``StorageVariablesProvider.serialize_value`` and read
    back with ``deserialize_value``. The value of a ``secret`` variable is never
    exposed in converted requests or responses; a fixed mask is returned.
    """

    # Placeholder returned instead of the value of a secret variable.
    _SECRET_MASK = "******"

    def __init__(self, serve_config: ServeConfig):
        super().__init__()
        self._serve_config = serve_config

    @staticmethod
    def _format_time(value: Any) -> Union[str, None]:
        """Format a datetime as ``%Y-%m-%d %H:%M:%S``; None stays None."""
        return value.strftime("%Y-%m-%d %H:%M:%S") if value is not None else None

    @classmethod
    def _visible_value(cls, entity: VariablesEntity) -> Any:
        """Return the deserialized value, or the mask for secret variables."""
        if entity.category == "secret":
            # No need to deserialize (possibly encrypted) data that is masked.
            return cls._SECRET_MASK
        return StorageVariablesProvider.deserialize_value(entity.value)

    def from_request(
        self, request: Union[VariablesRequest, Dict[str, Any]]
    ) -> VariablesEntity:
        """Convert the request to an entity.

        Args:
            request (Union[VariablesRequest, Dict[str, Any]]): The request

        Returns:
            VariablesEntity: A new entity built from the request.
        """
        request_dict = (
            model_to_dict(request) if isinstance(request, VariablesRequest) else request
        )
        value = StorageVariablesProvider.serialize_value(request_dict.get("value"))
        enabled_raw = request_dict.get("enabled")
        # A missing or None ``enabled`` means "use the default" (enabled); only
        # an explicit falsy value disables the variable.
        enabled = 1 if enabled_raw is None or enabled_raw else 0
        return VariablesEntity(
            key=request_dict.get("key"),
            name=request_dict.get("name"),
            label=request_dict.get("label"),
            value=value,
            value_type=request_dict.get("value_type"),
            category=request_dict.get("category"),
            encryption_method=request_dict.get("encryption_method"),
            salt=request_dict.get("salt"),
            scope=request_dict.get("scope"),
            scope_key=request_dict.get("scope_key"),
            enabled=enabled,
            user_name=request_dict.get("user_name"),
            sys_code=request_dict.get("sys_code"),
            description=request_dict.get("description"),
        )

    def to_request(self, entity: VariablesEntity) -> VariablesRequest:
        """Convert the entity to a request.

        Args:
            entity (VariablesEntity): The entity

        Returns:
            VariablesRequest: The request; secret values are masked.
        """
        return VariablesRequest(
            key=entity.key,
            name=entity.name,
            label=entity.label,
            value=self._visible_value(entity),
            value_type=entity.value_type,
            category=entity.category,
            encryption_method=entity.encryption_method,
            salt=entity.salt,
            scope=entity.scope,
            scope_key=entity.scope_key,
            enabled=entity.enabled == 1,
            user_name=entity.user_name,
            sys_code=entity.sys_code,
            description=entity.description,
        )

    def to_response(self, entity: VariablesEntity) -> VariablesResponse:
        """Convert the entity to a response.

        Args:
            entity (VariablesEntity): The entity

        Returns:
            VariablesResponse: The response; secret values are masked and
            timestamps are formatted as ``%Y-%m-%d %H:%M:%S`` (None if NULL).
        """
        return VariablesResponse(
            id=entity.id,
            key=entity.key,
            name=entity.name,
            label=entity.label,
            value=self._visible_value(entity),
            value_type=entity.value_type,
            category=entity.category,
            encryption_method=entity.encryption_method,
            salt=entity.salt,
            scope=entity.scope,
            scope_key=entity.scope_key,
            enabled=entity.enabled == 1,
            user_name=entity.user_name,
            sys_code=entity.sys_code,
            gmt_created=self._format_time(entity.gmt_created),
            gmt_modified=self._format_time(entity.gmt_modified),
            description=entity.description,
        )