feat(core): AWEL flow 2.0 backend code (#1879)

Co-authored-by: yhjun1026 <460342015@qq.com>
This commit is contained in:
Fangyin Cheng
2024-08-23 14:57:54 +08:00
committed by GitHub
parent 3a32344380
commit 9502251c08
67 changed files with 8289 additions and 190 deletions

View File

@@ -1,18 +1,29 @@
import io
import json
from functools import cache
from typing import List, Optional, Union
from typing import Dict, List, Literal, Optional, Union
from fastapi import APIRouter, Depends, HTTPException, Query, Request
from fastapi import APIRouter, Depends, File, HTTPException, Query, Request, UploadFile
from fastapi.security.http import HTTPAuthorizationCredentials, HTTPBearer
from starlette.responses import JSONResponse, StreamingResponse
from dbgpt.component import SystemApp
from dbgpt.core.awel.flow import ResourceMetadata, ViewMetadata
from dbgpt.core.awel.flow.flow_factory import FlowCategory
from dbgpt.serve.core import Result
from dbgpt.serve.core import Result, blocking_func_to_async
from dbgpt.util import PaginationResult
from ..config import APP_NAME, SERVE_SERVICE_COMPONENT_NAME, ServeConfig
from ..service.service import Service
from .schemas import ServeRequest, ServerResponse
from ..service.variables_service import VariablesService
from .schemas import (
FlowDebugRequest,
RefreshNodeRequest,
ServeRequest,
ServerResponse,
VariablesRequest,
VariablesResponse,
)
router = APIRouter()
@@ -23,7 +34,12 @@ global_system_app: Optional[SystemApp] = None
def get_service() -> Service:
"""Get the service instance"""
return global_system_app.get_component(SERVE_SERVICE_COMPONENT_NAME, Service)
return Service.get_instance(global_system_app)
def get_variable_service() -> VariablesService:
    """Resolve the variables service from the module-level system app.

    Returns:
        VariablesService: The instance returned by
            ``VariablesService.get_instance`` for ``global_system_app``.
    """
    variables_service = VariablesService.get_instance(global_system_app)
    return variables_service
get_bearer_token = HTTPBearer(auto_error=False)
@@ -102,7 +118,9 @@ async def test_auth():
@router.post(
"/flows", response_model=Result[None], dependencies=[Depends(check_api_key)]
"/flows",
response_model=Result[ServerResponse],
dependencies=[Depends(check_api_key)],
)
async def create(
request: ServeRequest, service: Service = Depends(get_service)
@@ -239,20 +257,236 @@ async def query_page(
@router.get("/nodes", dependencies=[Depends(check_api_key)])
async def get_nodes() -> Result[List[Union[ViewMetadata, ResourceMetadata]]]:
async def get_nodes(
user_name: Optional[str] = Query(default=None, description="user name"),
sys_code: Optional[str] = Query(default=None, description="system code"),
tags: Optional[str] = Query(default=None, description="tags"),
):
"""Get the operator or resource nodes
Args:
user_name (Optional[str]): The username
sys_code (Optional[str]): The system code
tags (Optional[str]): The tags encoded in JSON format
Returns:
Result[List[Union[ViewMetadata, ResourceMetadata]]]:
The operator or resource nodes
"""
from dbgpt.core.awel.flow.base import _OPERATOR_REGISTRY
return Result.succ(_OPERATOR_REGISTRY.metadata_list())
tags_dict: Optional[Dict[str, str]] = None
if tags:
try:
tags_dict = json.loads(tags)
except json.JSONDecodeError:
return Result.fail("Invalid JSON format for tags")
metadata_list = await blocking_func_to_async(
global_system_app,
_OPERATOR_REGISTRY.metadata_list,
tags_dict,
user_name,
sys_code,
)
return Result.succ(metadata_list)
@router.post("/nodes/refresh", dependencies=[Depends(check_api_key)])
async def refresh_nodes(refresh_request: RefreshNodeRequest):
    """Refresh the metadata of one operator or resource node.

    Args:
        refresh_request (RefreshNodeRequest): The node id, the node kind
            (``flow_type``) and the refresh options.

    Returns:
        Result: A success result wrapping whatever the registry refresh call
            returned.
    """
    from dbgpt.core.awel.flow.base import _OPERATOR_REGISTRY

    # The property is read only for its side effect: it makes sure the
    # variables provider is initialized before the registry is refreshed.
    _ = get_variable_service().variables_provider

    is_operator = refresh_request.flow_type == "operator"
    refreshed_metadata = await _OPERATOR_REGISTRY.refresh(
        refresh_request.id,
        is_operator,
        refresh_request.refresh,
        "http",
        global_system_app,
    )
    return Result.succ(refreshed_metadata)
@router.post(
    "/variables",
    response_model=Result[VariablesResponse],
    dependencies=[Depends(check_api_key)],
)
async def create_variables(
    variables_request: VariablesRequest,
) -> Result[VariablesResponse]:
    """Create a new variables entity.

    Args:
        variables_request (VariablesRequest): The variable to create

    Returns:
        Result[VariablesResponse]: A success result wrapping the value returned
            by the variables service ``create`` call.
    """
    variables_service = get_variable_service()
    # ``create`` is a blocking call, so it is dispatched through
    # ``blocking_func_to_async`` instead of being called inline.
    created = await blocking_func_to_async(
        global_system_app, variables_service.create, variables_request
    )
    return Result.succ(created)
@router.put(
    "/variables/{v_id}",
    response_model=Result[VariablesResponse],
    dependencies=[Depends(check_api_key)],
)
async def update_variables(
    v_id: int, variables_request: VariablesRequest
) -> Result[VariablesResponse]:
    """Update an existing variables entity.

    Args:
        v_id (int): The id of the variable to update
        variables_request (VariablesRequest): The new variable content

    Returns:
        Result[VariablesResponse]: A success result wrapping the value returned
            by the variables service ``update`` call.
    """
    variables_service = get_variable_service()
    # ``update`` is a blocking call, so it is dispatched through
    # ``blocking_func_to_async`` instead of being called inline.
    updated = await blocking_func_to_async(
        global_system_app, variables_service.update, v_id, variables_request
    )
    return Result.succ(updated)
@router.post("/flow/debug", dependencies=[Depends(check_api_key)])
async def debug_flow(
    flow_debug_request: FlowDebugRequest, service: Service = Depends(get_service)
):
    """Run the flow in debug mode and stream the output as an event stream.

    Args:
        flow_debug_request (FlowDebugRequest): The flow definition and the
            request to run it with.
        service (Service): The flow service

    Returns:
        StreamingResponse: A ``text/event-stream`` response fed by the debug run.
    """
    sse_headers = {
        "Content-Type": "text/event-stream",
        "Cache-Control": "no-cache",
        "Connection": "keep-alive",
        "Transfer-Encoding": "chunked",
    }
    # Ask for the non-incremental stream by default
    model_outputs = service.debug_flow(flow_debug_request, default_incremental=False)
    event_stream = service._wrapper_chat_stream_flow_str(model_outputs)
    return StreamingResponse(
        event_stream,
        headers=sse_headers,
        media_type="text/event-stream",
    )
def _attachment_disposition(file_name: str) -> str:
    """Build the ``Content-Disposition`` value for downloading ``file_name``.

    Names made only of URL-safe characters keep the plain
    ``attachment;filename=<name>`` form. Any other name (spaces, quotes,
    non-ASCII characters, ...) is percent-encoded and sent through the
    ``filename*`` parameter, so the header value stays ASCII-only.

    Args:
        file_name (str): The full download name, including the extension

    Returns:
        str: The header value
    """
    from urllib.parse import quote

    encoded_name = quote(file_name, safe="")
    if encoded_name == file_name:
        return f"attachment;filename={file_name}"
    return f"attachment;filename*=UTF-8''{encoded_name}"


@router.get("/flow/export/{uid}", dependencies=[Depends(check_api_key)])
async def export_flow(
    uid: str,
    export_type: Literal["json", "dbgpts"] = Query(
        "json", description="export type(json or dbgpts)"
    ),
    format: Literal["file", "json"] = Query(
        "file", description="response format(file or json)"
    ),
    file_name: Optional[str] = Query(default=None, description="file name to export"),
    user_name: Optional[str] = Query(default=None, description="user name"),
    sys_code: Optional[str] = Query(default=None, description="system code"),
    service: Service = Depends(get_service),
):
    """Export the flow as a JSON document or as a dbgpts zip package.

    Args:
        uid (str): The uid of the flow to export
        export_type (Literal["json", "dbgpts"]): The export type
        format (Literal["file", "json"]): Respond with a downloadable file or
            with an inline JSON body (the latter only for ``export_type=json``)
        file_name (Optional[str]): The download name without extension, defaults
            to the flow name with "_" replaced by "-"
        user_name (Optional[str]): The user name used to look up the flow
        sys_code (Optional[str]): The system code used to look up the flow
        service (Service): The flow service

    Returns:
        JSONResponse or StreamingResponse, depending on ``format`` and
        ``export_type``.

    Raises:
        HTTPException: 404 if the flow is not found, 400 if a JSON response is
            requested for a dbgpts export.
    """
    flow = service.get({"uid": uid, "user_name": user_name, "sys_code": sys_code})
    if not flow:
        raise HTTPException(status_code=404, detail=f"Flow {uid} not found")
    package_name = flow.name.replace("_", "-")
    file_name = file_name or package_name
    if export_type == "json":
        flow_dict = {"flow": flow.to_dict()}
        if format == "json":
            return JSONResponse(content=flow_dict)
        # Return the json file
        payload = json.dumps(flow_dict, ensure_ascii=False).encode("utf-8")
        return StreamingResponse(
            io.BytesIO(payload),
            media_type="application/file",
            headers={
                "Content-Disposition": _attachment_disposition(f"{file_name}.json")
            },
        )
    elif export_type == "dbgpts":
        from ..service.share_utils import _generate_dbgpts_zip

        if format == "json":
            raise HTTPException(
                status_code=400, detail="json response is not supported for dbgpts"
            )
        # Building the zip is blocking work, keep it off the event loop
        zip_buffer = await blocking_func_to_async(
            global_system_app, _generate_dbgpts_zip, package_name, flow
        )
        return StreamingResponse(
            zip_buffer,
            media_type="application/x-zip-compressed",
            headers={
                "Content-Disposition": _attachment_disposition(f"{file_name}.zip")
            },
        )
@router.post(
    "/flow/import",
    response_model=Result[ServerResponse],
    dependencies=[Depends(check_api_key)],
)
async def import_flow(
    file: UploadFile = File(...),
    save_flow: bool = Query(
        False, description="Whether to save the flow after importing"
    ),
    service: Service = Depends(get_service),
):
    """Import the flow from an uploaded ``.json`` or ``.zip`` file.

    Args:
        file (UploadFile): The uploaded file; its extension selects the parser
        save_flow (bool): Whether to save the flow after importing
        service (Service): The flow service

    Returns:
        Result: The saved flow when ``save_flow`` is true, otherwise the parsed
            flow.

    Raises:
        HTTPException: 400 if the extension is not supported, the JSON content
            cannot be parsed, or the JSON document has no 'flow' key.
    """
    # The upload may come without a filename; treat that as "no extension"
    # so that it ends up in the invalid-extension branch below.
    filename = file.filename or ""
    file_extension = filename.split(".")[-1].lower()
    if file_extension == "json":
        # Handle json file
        json_content = await file.read()
        try:
            json_dict = json.loads(json_content)
        except ValueError as e:
            # Covers both malformed JSON and undecodable bytes
            raise HTTPException(
                status_code=400, detail=f"invalid json file, {str(e)}"
            ) from e
        if not isinstance(json_dict, dict) or "flow" not in json_dict:
            raise HTTPException(
                status_code=400, detail="invalid json file, missing 'flow' key"
            )
        flow = ServeRequest.parse_obj(json_dict["flow"])
    elif file_extension == "zip":
        from ..service.share_utils import _parse_flow_from_zip_file

        # Handle zip file
        flow = await _parse_flow_from_zip_file(file, global_system_app)
    else:
        raise HTTPException(
            status_code=400, detail=f"invalid file extension {file_extension}"
        )
    if save_flow:
        return Result.succ(service.create_and_save_dag(flow))
    return Result.succ(flow)
def init_endpoints(system_app: SystemApp) -> None:
    """Register the flow components and remember the system app.

    Args:
        system_app (SystemApp): The system app the components are registered on;
            it is also stored in the module-level ``global_system_app``.
    """
    from .variables_provider import (
        BuiltinAllSecretVariablesProvider,
        BuiltinAllVariablesProvider,
        BuiltinEmbeddingsVariablesProvider,
        BuiltinFlowVariablesProvider,
        BuiltinLLMVariablesProvider,
        BuiltinNodeVariablesProvider,
    )

    global global_system_app

    # Registered one after another, in exactly this order
    component_classes = (
        Service,
        VariablesService,
        BuiltinFlowVariablesProvider,
        BuiltinNodeVariablesProvider,
        BuiltinAllVariablesProvider,
        BuiltinAllSecretVariablesProvider,
        BuiltinLLMVariablesProvider,
        BuiltinEmbeddingsVariablesProvider,
    )
    for component_cls in component_classes:
        system_app.register(component_cls)
    global_system_app = system_app

View File

@@ -1,7 +1,9 @@
from dbgpt._private.pydantic import ConfigDict
from typing import Any, Dict, List, Literal, Optional, Union
# Define your Pydantic schemas here
from dbgpt.core.awel.flow.flow_factory import FlowPanel
from dbgpt._private.pydantic import BaseModel, ConfigDict, Field
from dbgpt.core.awel import CommonLLMHttpRequestBody
from dbgpt.core.awel.flow.flow_factory import FlowPanel, VariablesRequest
from dbgpt.core.awel.util.parameter_util import RefreshOptionRequest
from ..config import SERVE_APP_NAME_HUMP
@@ -14,3 +16,71 @@ class ServerResponse(FlowPanel):
# TODO define your own fields here
model_config = ConfigDict(title=f"ServerResponse for {SERVE_APP_NAME_HUMP}")
class VariablesResponse(VariablesRequest):
    """Variable response model."""

    # The only field added on top of the fields inherited from VariablesRequest
    id: int = Field(..., description="The id of the variable", examples=[1])
class RefreshNodeRequest(BaseModel):
    """Request model for refreshing an operator or resource node."""

    model_config = ConfigDict(title="RefreshNodeRequest")

    # Node id as known to the operator/resource registry
    id: str = Field(
        ...,
        title="The id of the node",
        description="The id of the node to refresh",
        examples=["operator_llm_operator___$$___llm___$$___v1"],
    )
    # Selects whether the node is refreshed as an operator or as a resource
    flow_type: Literal["operator", "resource"] = Field(
        "operator",
        title="The type of the node",
        description="The type of the node to refresh",
        examples=["operator", "resource"],
    )
    type_name: str = Field(
        ...,
        title="The type of the node",
        description="The type of the node to refresh",
        examples=["LLMOperator"],
    )
    type_cls: str = Field(
        ...,
        title="The class of the node",
        description="The class of the node to refresh",
        examples=["dbgpt.core.operator.llm.LLMOperator"],
    )
    refresh: List[RefreshOptionRequest] = Field(
        ...,
        title="The refresh options",
        description="The refresh options",
    )
class FlowDebugRequest(BaseModel):
    """Request model for running a flow in debug mode."""

    model_config = ConfigDict(title="FlowDebugRequest")

    # The flow definition to build and run
    flow: ServeRequest = Field(
        ...,
        title="The flow to debug",
        description="The flow to debug",
    )
    # The input of the run: either a typed LLM request body or a plain dict
    request: Union[CommonLLMHttpRequestBody, Dict[str, Any]] = Field(
        ...,
        title="The request to debug",
        description="The request to debug",
    )
    variables: Optional[Dict[str, Any]] = Field(
        None,
        title="The variables to debug",
        description="The variables to debug",
    )
    user_name: Optional[str] = Field(None, description="User name")
    sys_code: Optional[str] = Field(None, description="System code")

View File

@@ -0,0 +1,260 @@
from typing import List, Literal, Optional
from dbgpt.core.interface.variables import (
BUILTIN_VARIABLES_CORE_EMBEDDINGS,
BUILTIN_VARIABLES_CORE_FLOW_NODES,
BUILTIN_VARIABLES_CORE_FLOWS,
BUILTIN_VARIABLES_CORE_LLMS,
BUILTIN_VARIABLES_CORE_SECRETS,
BUILTIN_VARIABLES_CORE_VARIABLES,
BuiltinVariablesProvider,
StorageVariables,
)
from ..service.service import Service
from .endpoints import get_service, get_variable_service
from .schemas import ServerResponse
class BuiltinFlowVariablesProvider(BuiltinVariablesProvider):
    """Builtin flow variables provider.

    Exposes the stored flows as variables under the builtin variable named by
    ``BUILTIN_VARIABLES_CORE_FLOWS``.
    """

    name = BUILTIN_VARIABLES_CORE_FLOWS

    def get_variables(
        self,
        key: str,
        scope: str = "global",
        scope_key: Optional[str] = None,
        sys_code: Optional[str] = None,
        user_name: Optional[str] = None,
    ) -> List[StorageVariables]:
        """Get one variable per flow returned by the flow service.

        Args:
            key (str): The variable key copied into every returned variable
            scope (str): The scope copied into every returned variable
            scope_key (Optional[str]): The scope key copied into every variable
            sys_code (Optional[str]): Used to filter the flows, and copied into
                every returned variable
            user_name (Optional[str]): Used to filter the flows, and copied into
                every returned variable

        Returns:
            List[StorageVariables]: Variables whose value is the flow uid
        """
        flow_service: Service = get_service()
        flow_query = {"user_name": user_name, "sys_code": sys_code}
        # Presumably page 1 with a page size of 1000 - TODO confirm against
        # ``Service.get_list_by_page``
        page_result = flow_service.get_list_by_page(flow_query, 1, 1000)
        flows: List[ServerResponse] = page_result.items
        return [
            StorageVariables(
                key=key,
                name=flow.name,
                label=flow.label,
                value=flow.uid,
                scope=scope,
                scope_key=scope_key,
                sys_code=sys_code,
                user_name=user_name,
            )
            for flow in flows
        ]
class BuiltinNodeVariablesProvider(BuiltinVariablesProvider):
    """Builtin node variables provider.

    Exposes the registered nodes as variables under the builtin variable named
    by ``BUILTIN_VARIABLES_CORE_FLOW_NODES``.
    """

    name = BUILTIN_VARIABLES_CORE_FLOW_NODES

    def get_variables(
        self,
        key: str,
        scope: str = "global",
        scope_key: Optional[str] = None,
        sys_code: Optional[str] = None,
        user_name: Optional[str] = None,
    ) -> List[StorageVariables]:
        """Get one variable per metadata entry of the operator registry.

        Returns:
            List[StorageVariables]: Variables whose value is the metadata "id"
        """
        from dbgpt.core.awel.flow.base import _OPERATOR_REGISTRY

        return [
            StorageVariables(
                key=key,
                name=node_metadata["name"],
                label=node_metadata["label"],
                value=node_metadata["id"],
                scope=scope,
                scope_key=scope_key,
                sys_code=sys_code,
                user_name=user_name,
            )
            for node_metadata in _OPERATOR_REGISTRY.metadata_list()
        ]
class BuiltinAllVariablesProvider(BuiltinVariablesProvider):
    """Builtin all variables provider.

    Exposes the stored variables under the builtin variable named by
    ``BUILTIN_VARIABLES_CORE_VARIABLES``.
    """

    name = BUILTIN_VARIABLES_CORE_VARIABLES

    def _get_variables_from_db(
        self,
        key: str,
        scope: str,
        scope_key: Optional[str],
        sys_code: Optional[str],
        user_name: Optional[str],
        category: Literal["common", "secret"] = "common",
    ) -> List[StorageVariables]:
        """List the variables of one category through the variables service.

        Only name, label and value are taken from each listed variable; the
        remaining fields are filled from the arguments.
        """
        listed = get_variable_service().list_all_variables(category)
        return [
            StorageVariables(
                key=key,
                name=stored.name,
                label=stored.label,
                value=stored.value,
                scope=scope,
                scope_key=scope_key,
                sys_code=sys_code,
                user_name=user_name,
            )
            for stored in listed
        ]

    def get_variables(
        self,
        key: str,
        scope: str = "global",
        scope_key: Optional[str] = None,
        sys_code: Optional[str] = None,
        user_name: Optional[str] = None,
    ) -> List[StorageVariables]:
        """Get the builtin variables (category "common").

        TODO: Return all builtin variables
        """
        common_variables = self._get_variables_from_db(
            key, scope, scope_key, sys_code, user_name
        )
        return common_variables
class BuiltinAllSecretVariablesProvider(BuiltinAllVariablesProvider):
    """Builtin all secret variables provider.

    Exposes the stored secret variables under the builtin variable named by
    ``BUILTIN_VARIABLES_CORE_SECRETS``.
    """

    name = BUILTIN_VARIABLES_CORE_SECRETS

    def get_variables(
        self,
        key: str,
        scope: str = "global",
        scope_key: Optional[str] = None,
        sys_code: Optional[str] = None,
        user_name: Optional[str] = None,
    ) -> List[StorageVariables]:
        """Get the builtin variables of category "secret"."""
        secret_variables = self._get_variables_from_db(
            key, scope, scope_key, sys_code, user_name, category="secret"
        )
        return secret_variables
class BuiltinLLMVariablesProvider(BuiltinVariablesProvider):
    """Builtin LLM variables provider.

    Exposes the healthy LLM model instances as variables under the builtin
    variable named by ``BUILTIN_VARIABLES_CORE_LLMS``.
    """

    name = BUILTIN_VARIABLES_CORE_LLMS

    def support_async(self) -> bool:
        """Whether the dynamic options support async."""
        return True

    async def _get_models(
        self,
        key: str,
        scope: str,
        scope_key: Optional[str],
        sys_code: Optional[str],
        user_name: Optional[str],
        expect_worker_type: str = "llm",
    ) -> List[StorageVariables]:
        """Get one variable per distinct worker name of the expected type.

        Instance names are read as ``<worker_name>@<worker_type>``. The type is
        whatever follows the last "@", so a worker name may itself contain "@".
        Names without any "@" carry no type and are skipped instead of
        aborting the whole listing.

        Args:
            key (str): The variable key copied into every returned variable
            scope (str): The scope copied into every returned variable
            scope_key (Optional[str]): The scope key copied into every variable
            sys_code (Optional[str]): The system code copied into every variable
            user_name (Optional[str]): The user name copied into every variable
            expect_worker_type (str): Only instances of this worker type are
                returned

        Returns:
            List[StorageVariables]: Variables whose name, label and value are
                the worker name, in first-seen order.
        """
        from dbgpt.model.cluster.controller.controller import BaseModelController

        controller = BaseModelController.get_instance(self.system_app)
        instances = await controller.get_all_instances(healthy_only=True)
        # A dict is used as an insertion-ordered set of worker names
        worker_names = {}
        for instance in instances:
            worker_name, separator, worker_type = instance.model_name.rpartition("@")
            if not separator:
                continue
            if worker_type == expect_worker_type:
                worker_names[worker_name] = None
        return [
            StorageVariables(
                key=key,
                name=worker_name,
                label=worker_name,
                value=worker_name,
                scope=scope,
                scope_key=scope_key,
                sys_code=sys_code,
                user_name=user_name,
            )
            for worker_name in worker_names
        ]

    async def async_get_variables(
        self,
        key: str,
        scope: str = "global",
        scope_key: Optional[str] = None,
        sys_code: Optional[str] = None,
        user_name: Optional[str] = None,
    ) -> List[StorageVariables]:
        """Get the builtin variables."""
        return await self._get_models(key, scope, scope_key, sys_code, user_name)

    def get_variables(
        self,
        key: str,
        scope: str = "global",
        scope_key: Optional[str] = None,
        sys_code: Optional[str] = None,
        user_name: Optional[str] = None,
    ) -> List[StorageVariables]:
        """Get the builtin variables.

        Raises:
            NotImplementedError: Always, only ``async_get_variables`` is
                implemented.
        """
        raise NotImplementedError(
            "Not implemented get variables sync, please use async_get_variables"
        )
class BuiltinEmbeddingsVariablesProvider(BuiltinLLMVariablesProvider):
    """Builtin embeddings variables provider.

    Exposes the embedding model instances under the builtin variable named by
    ``BUILTIN_VARIABLES_CORE_EMBEDDINGS``.
    """

    name = BUILTIN_VARIABLES_CORE_EMBEDDINGS

    async def async_get_variables(
        self,
        key: str,
        scope: str = "global",
        scope_key: Optional[str] = None,
        sys_code: Optional[str] = None,
        user_name: Optional[str] = None,
    ) -> List[StorageVariables]:
        """Get the builtin variables for the "text2vec" worker type."""
        embedding_variables = await self._get_models(
            key, scope, scope_key, sys_code, user_name, expect_worker_type="text2vec"
        )
        return embedding_variables

View File

@@ -8,8 +8,10 @@ SERVE_APP_NAME = "dbgpt_serve_flow"
SERVE_APP_NAME_HUMP = "dbgpt_serve_Flow"
SERVE_CONFIG_KEY_PREFIX = "dbgpt.serve.flow."
SERVE_SERVICE_COMPONENT_NAME = f"{SERVE_APP_NAME}_service"
SERVE_VARIABLES_SERVICE_COMPONENT_NAME = f"{SERVE_APP_NAME}_variables_service"
# Database table name
SERVER_APP_TABLE_NAME = "dbgpt_serve_flow"
SERVER_APP_VARIABLES_TABLE_NAME = "dbgpt_serve_variables"
@dataclass
@@ -23,3 +25,6 @@ class ServeConfig(BaseServeConfig):
load_dbgpts_interval: int = field(
default=5, metadata={"help": "Interval to load dbgpts from installed packages"}
)
encrypt_key: Optional[str] = field(
default=None, metadata={"help": "The key to encrypt the data"}
)

View File

@@ -10,11 +10,17 @@ from sqlalchemy import Column, DateTime, Integer, String, Text, UniqueConstraint
from dbgpt._private.pydantic import model_to_dict
from dbgpt.core.awel.flow.flow_factory import State
from dbgpt.core.interface.variables import StorageVariablesProvider
from dbgpt.storage.metadata import BaseDao, Model
from dbgpt.storage.metadata._base_dao import QUERY_SPEC
from ..api.schemas import ServeRequest, ServerResponse
from ..config import SERVER_APP_TABLE_NAME, ServeConfig
from ..api.schemas import (
ServeRequest,
ServerResponse,
VariablesRequest,
VariablesResponse,
)
from ..config import SERVER_APP_TABLE_NAME, SERVER_APP_VARIABLES_TABLE_NAME, ServeConfig
class ServeEntity(Model):
@@ -43,6 +49,7 @@ class ServeEntity(Model):
editable = Column(
Integer, nullable=True, comment="Editable, 0: editable, 1: not editable"
)
variables = Column(Text, nullable=True, comment="Flow variables, JSON format")
user_name = Column(String(128), index=True, nullable=True, comment="User name")
sys_code = Column(String(128), index=True, nullable=True, comment="System code")
gmt_created = Column(DateTime, default=datetime.now, comment="Record creation time")
@@ -74,6 +81,57 @@ class ServeEntity(Model):
return editable is None or editable == 0
class VariablesEntity(Model):
    """Database model of one variable row."""

    __tablename__ = SERVER_APP_VARIABLES_TABLE_NAME

    id = Column(Integer, primary_key=True, comment="Auto increment id")

    # Identity of the variable
    key = Column(String(128), index=True, nullable=False, comment="Variable key")
    name = Column(String(128), index=True, nullable=True, comment="Variable name")
    label = Column(String(128), nullable=True, comment="Variable label")

    # Value and how to interpret it
    value = Column(Text, nullable=True, comment="Variable value, JSON format")
    value_type = Column(
        String(32),
        nullable=True,
        comment="Variable value type(string, int, float, bool)",
    )
    category = Column(
        String(32),
        default="common",
        nullable=True,
        comment="Variable category(common or secret)",
    )
    encryption_method = Column(
        String(32),
        nullable=True,
        comment="Variable encryption method(fernet, simple, rsa, aes)",
    )
    salt = Column(String(128), nullable=True, comment="Variable salt")

    # Where the variable applies
    scope = Column(
        String(32),
        default="global",
        nullable=True,
        comment="Variable scope(global,flow,app,agent,datasource,flow_priv,agent_priv, "
        "etc)",
    )
    scope_key = Column(
        String(256),
        nullable=True,
        comment="Variable scope key, default is empty, for scope is 'flow_priv', "
        "the scope_key is dag id of flow",
    )
    enabled = Column(
        Integer,
        default=1,
        nullable=True,
        comment="Variable enabled, 0: disabled, 1: enabled",
    )
    description = Column(Text, nullable=True, comment="Variable description")

    # Ownership and bookkeeping
    user_name = Column(String(128), index=True, nullable=True, comment="User name")
    sys_code = Column(String(128), index=True, nullable=True, comment="System code")
    gmt_created = Column(DateTime, default=datetime.now, comment="Record creation time")
    gmt_modified = Column(DateTime, default=datetime.now, comment="Record update time")
class ServeDao(BaseDao[ServeEntity, ServeRequest, ServerResponse]):
"""The DAO class for Flow"""
@@ -98,6 +156,11 @@ class ServeDao(BaseDao[ServeEntity, ServeRequest, ServerResponse]):
error_message = request_dict.get("error_message")
if error_message:
error_message = error_message[:500]
variables_raw = request_dict.get("variables")
variables = (
json.dumps(variables_raw, ensure_ascii=False) if variables_raw else None
)
new_dict = {
"uid": request_dict.get("uid"),
"dag_id": request_dict.get("dag_id"),
@@ -113,6 +176,7 @@ class ServeDao(BaseDao[ServeEntity, ServeRequest, ServerResponse]):
"define_type": request_dict.get("define_type"),
"editable": ServeEntity.parse_editable(request_dict.get("editable")),
"description": request_dict.get("description"),
"variables": variables,
"user_name": request_dict.get("user_name"),
"sys_code": request_dict.get("sys_code"),
}
@@ -129,6 +193,8 @@ class ServeDao(BaseDao[ServeEntity, ServeRequest, ServerResponse]):
REQ: The request
"""
flow_data = json.loads(entity.flow_data)
variables_raw = json.loads(entity.variables) if entity.variables else None
variables = ServeRequest.parse_variables(variables_raw)
return ServeRequest(
uid=entity.uid,
dag_id=entity.dag_id,
@@ -144,6 +210,7 @@ class ServeDao(BaseDao[ServeEntity, ServeRequest, ServerResponse]):
define_type=entity.define_type,
editable=ServeEntity.to_bool_editable(entity.editable),
description=entity.description,
variables=variables,
user_name=entity.user_name,
sys_code=entity.sys_code,
)
@@ -160,6 +227,8 @@ class ServeDao(BaseDao[ServeEntity, ServeRequest, ServerResponse]):
flow_data = json.loads(entity.flow_data)
gmt_created_str = entity.gmt_created.strftime("%Y-%m-%d %H:%M:%S")
gmt_modified_str = entity.gmt_modified.strftime("%Y-%m-%d %H:%M:%S")
variables_raw = json.loads(entity.variables) if entity.variables else None
variables = ServeRequest.parse_variables(variables_raw)
return ServerResponse(
uid=entity.uid,
dag_id=entity.dag_id,
@@ -175,6 +244,7 @@ class ServeDao(BaseDao[ServeEntity, ServeRequest, ServerResponse]):
version=entity.version,
editable=ServeEntity.to_bool_editable(entity.editable),
define_type=entity.define_type,
variables=variables,
user_name=entity.user_name,
sys_code=entity.sys_code,
gmt_created=gmt_created_str,
@@ -215,6 +285,14 @@ class ServeDao(BaseDao[ServeEntity, ServeRequest, ServerResponse]):
entry.editable = ServeEntity.parse_editable(update_request.editable)
if update_request.define_type:
entry.define_type = update_request.define_type
if update_request.variables:
variables_raw = update_request.get_variables_dict()
entry.variables = (
json.dumps(variables_raw, ensure_ascii=False)
if variables_raw
else None
)
if update_request.user_name:
entry.user_name = update_request.user_name
if update_request.sys_code:
@@ -222,3 +300,111 @@ class ServeDao(BaseDao[ServeEntity, ServeRequest, ServerResponse]):
session.merge(entry)
session.commit()
return self.get_one(query_request)
class VariablesDao(BaseDao[VariablesEntity, VariablesRequest, VariablesResponse]):
    """The DAO class for Variables"""

    def __init__(self, serve_config: ServeConfig):
        super().__init__()
        self._serve_config = serve_config

    @staticmethod
    def _visible_value(entity: VariablesEntity) -> Any:
        """Return the value to expose for ``entity``.

        The stored value is always deserialized first; for the "secret"
        category the result is then replaced by a fixed mask.
        """
        value = StorageVariablesProvider.deserialize_value(entity.value)
        if entity.category == "secret":
            return "******"
        return value

    def from_request(
        self, request: Union[VariablesRequest, Dict[str, Any]]
    ) -> VariablesEntity:
        """Convert the request to an entity

        Args:
            request (Union[VariablesRequest, Dict[str, Any]]): The request

        Returns:
            T: The entity
        """
        if isinstance(request, VariablesRequest):
            payload = model_to_dict(request)
        else:
            payload = request
        # Fields copied over unchanged
        copied_fields = (
            "key",
            "name",
            "label",
            "value_type",
            "category",
            "encryption_method",
            "salt",
            "scope",
            "scope_key",
            "user_name",
            "sys_code",
            "description",
        )
        columns = {field_name: payload.get(field_name) for field_name in copied_fields}
        # The value is stored serialized, the enabled flag as 0/1
        columns["value"] = StorageVariablesProvider.serialize_value(
            payload.get("value")
        )
        columns["enabled"] = 1 if payload.get("enabled", True) else 0
        return VariablesEntity(**columns)

    def to_request(self, entity: VariablesEntity) -> VariablesRequest:
        """Convert the entity to a request

        Args:
            entity (T): The entity

        Returns:
            REQ: The request
        """
        return VariablesRequest(
            key=entity.key,
            name=entity.name,
            label=entity.label,
            value=self._visible_value(entity),
            value_type=entity.value_type,
            category=entity.category,
            encryption_method=entity.encryption_method,
            salt=entity.salt,
            scope=entity.scope,
            scope_key=entity.scope_key,
            enabled=entity.enabled == 1,
            user_name=entity.user_name,
            sys_code=entity.sys_code,
            description=entity.description,
        )

    def to_response(self, entity: VariablesEntity) -> VariablesResponse:
        """Convert the entity to a response

        Args:
            entity (T): The entity

        Returns:
            RES: The response
        """
        value = self._visible_value(entity)
        time_format = "%Y-%m-%d %H:%M:%S"
        return VariablesResponse(
            id=entity.id,
            key=entity.key,
            name=entity.name,
            label=entity.label,
            value=value,
            value_type=entity.value_type,
            category=entity.category,
            encryption_method=entity.encryption_method,
            salt=entity.salt,
            scope=entity.scope,
            scope_key=entity.scope_key,
            enabled=entity.enabled == 1,
            user_name=entity.user_name,
            sys_code=entity.sys_code,
            gmt_created=entity.gmt_created.strftime(time_format),
            gmt_modified=entity.gmt_modified.strftime(time_format),
            description=entity.description,
        )

View File

@@ -0,0 +1,71 @@
from typing import Type
from sqlalchemy.orm import Session
from dbgpt.core.interface.storage import StorageItemAdapter
from dbgpt.core.interface.variables import StorageVariables, VariablesIdentifier
from .models import VariablesEntity
class VariablesAdapter(StorageItemAdapter[StorageVariables, VariablesEntity]):
    """Variables adapter.

    Convert between storage format and database model.
    """

    # Attributes copied one-to-one in both conversion directions
    _COPIED_FIELDS = (
        "key",
        "name",
        "label",
        "value",
        "value_type",
        "category",
        "encryption_method",
        "salt",
        "scope",
        "scope_key",
        "sys_code",
        "user_name",
        "description",
    )

    def to_storage_format(self, item: StorageVariables) -> VariablesEntity:
        """Convert to storage format."""
        columns = {
            field_name: getattr(item, field_name)
            for field_name in self._COPIED_FIELDS
        }
        return VariablesEntity(**columns)

    def from_storage_format(self, model: VariablesEntity) -> StorageVariables:
        """Convert from storage format."""
        attributes = {
            field_name: getattr(model, field_name)
            for field_name in self._COPIED_FIELDS
        }
        return StorageVariables(**attributes)

    def get_query_for_identifier(
        self,
        storage_format: Type[VariablesEntity],
        resource_id: VariablesIdentifier,
        **kwargs,
    ):
        """Get query for identifier.

        The query matches every non-None field of the identifier and only
        returns enabled rows.

        Raises:
            Exception: If no ``session`` keyword argument is given
        """
        session: Session = kwargs.get("session")
        if session is None:
            raise Exception("session is None")
        query_obj = session.query(VariablesEntity)
        identifier_filters = {
            column_name: expected
            for column_name, expected in resource_id.to_dict().items()
            if expected is not None
        }
        for column_name, expected in identifier_filters.items():
            query_obj = query_obj.filter(
                getattr(VariablesEntity, column_name) == expected
            )
        # enabled must be True
        return query_obj.filter(VariablesEntity.enabled == 1)

View File

@@ -4,6 +4,7 @@ from typing import List, Optional, Union
from sqlalchemy import URL
from dbgpt.component import SystemApp
from dbgpt.core.interface.variables import VariablesProvider
from dbgpt.serve.core import BaseServe
from dbgpt.storage.metadata import DatabaseManager
@@ -40,6 +41,8 @@ class Serve(BaseServe):
system_app, api_prefix, api_tags, db_url_or_db, try_create_tables
)
self._db_manager: Optional[DatabaseManager] = None
self._variables_provider: Optional[VariablesProvider] = None
self._serve_config: Optional[ServeConfig] = None
def init_app(self, system_app: SystemApp):
if self._app_has_initiated:
@@ -62,5 +65,37 @@ class Serve(BaseServe):
def before_start(self):
    """Called before the start of the application.

    Loads the serve config and builds the variables provider on top of a
    SQLAlchemy storage for ``VariablesEntity``.
    """
    from dbgpt.core.interface.variables import (
        FernetEncryption,
        StorageVariablesProvider,
    )
    from dbgpt.storage.metadata.db_storage import SQLAlchemyStorage
    from dbgpt.util.serialization.json_serialization import JsonSerializer

    # NOTE(review): ServeEntity is not referenced below; the import is kept in
    # case importing it matters elsewhere - confirm before removing.
    from .models.models import ServeEntity, VariablesEntity
    from .models.variables_adapter import VariablesAdapter

    # Create (or fetch) the database manager once; it is reused by the storage
    self._db_manager = self.create_or_get_db_manager()
    self._serve_config = ServeConfig.from_app_config(
        self._system_app.config, SERVE_CONFIG_KEY_PREFIX
    )

    storage_adapter = VariablesAdapter()
    serializer = JsonSerializer()
    storage = SQLAlchemyStorage(
        self._db_manager,
        VariablesEntity,
        storage_adapter,
        serializer,
    )
    self._variables_provider = StorageVariablesProvider(
        storage=storage,
        encryption=FernetEncryption(self._serve_config.encrypt_key),
        system_app=self._system_app,
    )
@property
def variables_provider(self):
    """The variables provider of the serve app, backed by db storage.

    Returns the current value of ``self._variables_provider``.
    """
    provider = self._variables_provider
    return provider

View File

@@ -9,7 +9,6 @@ from dbgpt._private.pydantic import model_to_json
from dbgpt.agent import AgentDummyTrigger
from dbgpt.component import SystemApp
from dbgpt.core.awel import DAG, BaseOperator, CommonLLMHttpRequestBody
from dbgpt.core.awel.dag.dag_manager import DAGManager
from dbgpt.core.awel.flow.flow_factory import (
FlowCategory,
FlowFactory,
@@ -34,7 +33,7 @@ from dbgpt.storage.metadata._base_dao import QUERY_SPEC
from dbgpt.util.dbgpts.loader import DBGPTsLoader
from dbgpt.util.pagination_utils import PaginationResult
from ..api.schemas import ServeRequest, ServerResponse
from ..api.schemas import FlowDebugRequest, ServeRequest, ServerResponse
from ..config import SERVE_CONFIG_KEY_PREFIX, SERVE_SERVICE_COMPONENT_NAME, ServeConfig
from ..models.models import ServeDao, ServeEntity
@@ -147,7 +146,9 @@ class Service(BaseService[ServeEntity, ServeRequest, ServerResponse]):
raise ValueError(
f"Create DAG {request.name} error, define_type: {request.define_type}, error: {str(e)}"
) from e
res = self.dao.create(request)
self.dao.create(request)
# Query from database
res = self.get({"uid": request.uid})
state = request.state
try:
@@ -574,3 +575,61 @@ class Service(BaseService[ServeEntity, ServeRequest, ServerResponse]):
return FlowCategory.CHAT_FLOW
except Exception:
return FlowCategory.COMMON
async def debug_flow(
    self, request: FlowDebugRequest, default_incremental: Optional[bool] = None
) -> AsyncIterator[ModelOutput]:
    """Build the flow of the request and stream the output of its leaf task.

    Args:
        request (FlowDebugRequest): The flow to build and the request to run
        default_incremental (Optional[bool]): When not None, it replaces the
            ``incremental`` setting read from the request

    Returns:
        AsyncIterator[ModelOutput]: The streamed output; errors raised while
            streaming are yielded as an output with ``error_code=1``.

    Raises:
        ValueError: If the DAG does not have exactly one leaf node, or the
            request is neither a ``CommonLLMHttpRequestBody`` nor a dict.
    """
    from dbgpt.core.awel.dag.dag_manager import DAGMetadata, _parse_metadata

    dag = self._flow_factory.build(request.flow)
    leaf_nodes = dag.leaf_nodes
    if len(leaf_nodes) != 1:
        raise ValueError("Chat Flow just support one leaf node in dag")
    task = cast(BaseOperator, leaf_nodes[0])
    # The parsed metadata is not used yet; the call itself is kept
    dag_metadata = _parse_metadata(dag)
    # TODO: Run task with variables
    variables = request.variables

    payload = request.request
    if isinstance(payload, CommonLLMHttpRequestBody):
        incremental = payload.incremental
    elif isinstance(payload, dict):
        incremental = payload.get("incremental", False)
    else:
        raise ValueError("Invalid request type")
    # An explicit default wins over what the request asked for
    incremental = incremental if default_incremental is None else default_incremental

    try:
        async for chunk in safe_chat_stream_with_dag_task(task, payload, incremental):
            yield chunk
    except HTTPException as http_error:
        yield ModelOutput(
            error_code=1, text=http_error.detail, incremental=incremental
        )
    except Exception as error:
        yield ModelOutput(error_code=1, text=str(error), incremental=incremental)
async def _wrapper_chat_stream_flow_str(
    self, stream_iter: AsyncIterator[ModelOutput]
) -> AsyncIterator[str]:
    """Turn a stream of model outputs into ``data:``-prefixed text frames.

    Args:
        stream_iter (AsyncIterator[ModelOutput]): The outputs to render.

    Yields:
        str: ``data:<text>`` followed by a blank line for each output whose
            ``error_code`` is 0. The first output with a non-zero
            ``error_code`` is rendered as ``data:[SERVER_ERROR]<text>`` and
            ends the stream.
    """
    async for item in stream_iter:
        raw = item.text
        # Escape real newlines so they cannot be mistaken for the blank line
        # that terminates each frame.
        payload = raw.replace("\n", "\\n") if raw else raw
        if item.error_code == 0:
            yield f"data:{payload}\n\n"
            continue
        yield f"data:[SERVER_ERROR]{payload}\n\n"
        return

View File

@@ -0,0 +1,121 @@
import io
import json
import os
import tempfile
import zipfile
import aiofiles
import tomlkit
from fastapi import UploadFile
from dbgpt.component import SystemApp
from dbgpt.serve.core import blocking_func_to_async
from ..api.schemas import ServeRequest
def _generate_dbgpts_zip(package_name: str, flow: ServeRequest) -> io.BytesIO:
    """Pack a flow into an in-memory dbgpts package archive.

    The archive contains, under ``<package_name>/``:

    * ``MANIFEST.in`` and ``README.md``
    * ``<flow.name>/__init__.py`` (empty)
    * ``<flow.name>/definition/flow_definition.json`` (the flow data as JSON)
    * ``dbgpts.toml`` (flow metadata) and ``pyproject.toml`` (poetry metadata)

    Args:
        package_name (str): Top-level directory name inside the archive; also
            used as the poetry project name.
        flow (ServeRequest): The flow to export. Its ``name``, ``label``,
            ``description`` and ``flow_data`` are read.

    Returns:
        io.BytesIO: The zip content, positioned at offset 0.
    """
    zip_buffer = io.BytesIO()
    flow_name = flow.name
    flow_label = flow.label
    flow_description = flow.description
    dag_json = json.dumps(flow.flow_data.dict(), indent=4, ensure_ascii=False)

    with zipfile.ZipFile(zip_buffer, "a", zipfile.ZIP_DEFLATED, False) as zip_file:
        manifest = f"include dbgpts.toml\ninclude {flow_name}/definition/*.json"
        readme = f"# {flow_label}\n\n{flow_description}"
        zip_file.writestr(f"{package_name}/MANIFEST.in", manifest)
        zip_file.writestr(f"{package_name}/README.md", readme)
        zip_file.writestr(
            f"{package_name}/{flow_name}/__init__.py",
            "",
        )
        zip_file.writestr(
            f"{package_name}/{flow_name}/definition/flow_definition.json",
            dag_json,
        )
        zip_file.writestr(
            f"{package_name}/dbgpts.toml",
            _render_dbgpts_toml(flow_name, flow_label, flow_description),
        )
        zip_file.writestr(
            f"{package_name}/pyproject.toml",
            _render_pyproject_toml(package_name),
        )

    zip_buffer.seek(0)
    return zip_buffer


def _render_dbgpts_toml(flow_name: str, flow_label: str, flow_description: str) -> str:
    """Render the ``dbgpts.toml`` text describing one exported flow."""
    dbgpts_toml = tomlkit.document()

    # Add flow information. The label and name come from the flow being
    # exported (they used to be hard-coded sample values), so the metadata
    # matches the "<flow_name>/" directory written into the archive.
    dbgpts_flow_toml = tomlkit.document()
    dbgpts_flow_toml.add("label", flow_label)
    name_with_comment = tomlkit.string(flow_name)
    name_with_comment.comment("A unique name for all dbgpts")
    dbgpts_flow_toml.add("name", name_with_comment)
    dbgpts_flow_toml.add("version", "0.1.0")
    dbgpts_flow_toml.add(
        "description",
        flow_description,
    )
    dbgpts_flow_toml.add("authors", [])
    definition_type_with_comment = tomlkit.string("json")
    definition_type_with_comment.comment("How to define the flow, python or json")
    dbgpts_flow_toml.add("definition_type", definition_type_with_comment)
    dbgpts_toml.add("flow", dbgpts_flow_toml)

    # Add python and json config
    python_config = tomlkit.table()
    dbgpts_toml.add("python_config", python_config)
    json_config = tomlkit.table()
    json_config.add("file_path", "definition/flow_definition.json")
    json_config.comment("Json config")
    dbgpts_toml.add("json_config", json_config)

    return tomlkit.dumps(dbgpts_toml)


def _render_pyproject_toml(package_name: str) -> str:
    """Render the poetry ``pyproject.toml`` text for the exported package."""
    pyproject_toml = tomlkit.document()

    # Add [tool.poetry] section
    tool_poetry_toml = tomlkit.table()
    tool_poetry_toml.add("name", package_name)
    tool_poetry_toml.add("version", "0.1.0")
    tool_poetry_toml.add("description", "A dbgpts package")
    tool_poetry_toml.add("authors", [])
    tool_poetry_toml.add("readme", "README.md")
    pyproject_toml["tool"] = tomlkit.table()
    pyproject_toml["tool"]["poetry"] = tool_poetry_toml

    # Add [tool.poetry.dependencies] section
    dependencies = tomlkit.table()
    dependencies.add("python", "^3.10")
    pyproject_toml["tool"]["poetry"]["dependencies"] = dependencies

    # Add [build-system] section
    build_system = tomlkit.table()
    build_system.add("requires", ["poetry-core"])
    build_system.add("build-backend", "poetry.core.masonry.api")
    pyproject_toml["build-system"] = build_system

    return tomlkit.dumps(pyproject_toml)
async def _parse_flow_from_zip_file(
    file: UploadFile, sys_app: SystemApp
) -> ServeRequest:
    """Save an uploaded ZIP to a temporary directory and load the flow from it.

    Args:
        file (UploadFile): The uploaded file; its name must end with ``.zip``.
        sys_app (SystemApp): Passed to ``blocking_func_to_async`` to run the
            package loader.

    Returns:
        ServeRequest: The flow returned by the package loader.

    Raises:
        ValueError: If the upload has no file name or the name does not end
            with ``.zip``.
    """
    from dbgpt.util.dbgpts.loader import _load_flow_package_from_zip_path

    filename = file.filename
    # The name is supplied by the client and may be missing altogether.
    if not filename or not filename.endswith(".zip"):
        raise ValueError("Uploaded file must be a ZIP file")
    # Keep only the last path component (for "/" and "\" separators alike) so
    # a name such as "../../x.zip" or "/abs/x.zip" cannot place the file
    # outside the temporary directory.
    safe_name = os.path.basename(filename.replace("\\", "/"))

    with tempfile.TemporaryDirectory() as temp_dir:
        zip_path = os.path.join(temp_dir, safe_name)

        # Save uploaded file to temporary directory
        async with aiofiles.open(zip_path, "wb") as out_file:
            while content := await file.read(1024 * 64):  # Read in chunks of 64KB
                await out_file.write(content)
        # Load inside the ``with`` block: the directory is removed on exit.
        flow = await blocking_func_to_async(
            sys_app, _load_flow_package_from_zip_path, zip_path
        )
        return flow

View File

@@ -0,0 +1,152 @@
from typing import List, Optional
from dbgpt import SystemApp
from dbgpt.core.interface.variables import StorageVariables, VariablesProvider
from dbgpt.serve.core import BaseService
from ..api.schemas import VariablesRequest, VariablesResponse
from ..config import (
SERVE_CONFIG_KEY_PREFIX,
SERVE_VARIABLES_SERVICE_COMPONENT_NAME,
ServeConfig,
)
from ..models.models import VariablesDao, VariablesEntity
class VariablesService(
    BaseService[VariablesEntity, VariablesRequest, VariablesResponse]
):
    """Service for creating, updating and listing stored variables."""

    name = SERVE_VARIABLES_SERVICE_COMPONENT_NAME

    def __init__(self, system_app: SystemApp, dao: Optional[VariablesDao] = None):
        """Create the service.

        Args:
            system_app (SystemApp): The system app
            dao (Optional[VariablesDao]): DAO to use; when omitted one is
                created in ``init_app``.
        """
        # Attributes are assigned before the base initializer runs; keep this
        # order (presumably the base class calls ``init_app`` -- TODO confirm).
        self._system_app = None
        self._serve_config: ServeConfig = None
        self._dao: VariablesDao = dao
        super().__init__(system_app)

    def init_app(self, system_app: SystemApp) -> None:
        """Load the serve config, ensure a DAO exists and keep the system app.

        Args:
            system_app (SystemApp): The system app
        """
        super().init_app(system_app)
        self._serve_config = ServeConfig.from_app_config(
            system_app.config, SERVE_CONFIG_KEY_PREFIX
        )
        # Keep an injected DAO; otherwise build one from the loaded config.
        if not self._dao:
            self._dao = VariablesDao(self._serve_config)
        self._system_app = system_app

    @property
    def dao(self) -> VariablesDao:
        """The DAO used by this service."""
        return self._dao

    @property
    def variables_provider(self) -> VariablesProvider:
        """The variables provider used to persist variables.

        Returns:
            VariablesProvider: The provider registered on the system app if
                there is one; otherwise the serve app's provider, which is
                then registered on the system app.
        """
        provider = VariablesProvider.get_instance(
            self._system_app, default_component=None
        )
        if provider:
            return provider

        from ..serve import Serve

        provider = Serve.get_instance(self._system_app).variables_provider
        self._system_app.register_instance(provider)
        return provider

    @property
    def config(self) -> ServeConfig:
        """The serve configuration loaded in ``init_app``."""
        return self._serve_config

    @staticmethod
    def _to_storage_variables(request: VariablesRequest) -> StorageVariables:
        """Copy the request fields into a ``StorageVariables`` object."""
        return StorageVariables(
            key=request.key,
            name=request.name,
            label=request.label,
            value=request.value,
            value_type=request.value_type,
            category=request.category,
            scope=request.scope,
            scope_key=request.scope_key,
            user_name=request.user_name,
            sys_code=request.sys_code,
            # Stored as an integer flag rather than a bool.
            enabled=1 if request.enabled else 0,
            description=request.description,
        )

    @staticmethod
    def _lookup_query(request: VariablesRequest) -> dict:
        """Build the DAO query used to read back the variable of a request."""
        return {
            "key": request.key,
            "name": request.name,
            "scope": request.scope,
            "scope_key": request.scope_key,
            "sys_code": request.sys_code,
            "user_name": request.user_name,
            "enabled": request.enabled,
        }

    def create(self, request: VariablesRequest) -> VariablesResponse:
        """Save a variable and return it as read back through the DAO.

        Args:
            request (VariablesRequest): The request

        Returns:
            VariablesResponse: The response
        """
        self.variables_provider.save(self._to_storage_variables(request))
        return self.dao.get_one(self._lookup_query(request))

    def update(self, _: int, request: VariablesRequest) -> VariablesResponse:
        """Overwrite an existing variable and return it as read back.

        Args:
            _ (int): Unused.
            request (VariablesRequest): The request

        Returns:
            VariablesResponse: The response

        Raises:
            ValueError: If the provider returns ``None`` for the variable's
                identifier.
        """
        storage_variables = self._to_storage_variables(request)
        identifier = storage_variables.identifier.str_identifier
        # ``None`` is passed as the default value for the lookup.
        current_value = self.variables_provider.get(identifier, None)
        if current_value is None:
            raise ValueError(
                f"Variable {storage_variables.identifier.str_identifier} not found"
            )
        self.variables_provider.save(storage_variables)
        return self.dao.get_one(self._lookup_query(request))

    def list_all_variables(self, category: str = "common") -> List[VariablesResponse]:
        """List the enabled variables of one category.

        Args:
            category (str): Category to filter by. Defaults to "common".

        Returns:
            List[VariablesResponse]: The matching variables
        """
        return self.dao.get_list({"enabled": True, "category": category})