mirror of
https://github.com/csunny/DB-GPT.git
synced 2025-09-15 14:11:14 +00:00
feat(core): AWEL flow 2.0 backend code (#1879)
Co-authored-by: yhjun1026 <460342015@qq.com>
This commit is contained in:
@@ -1,18 +1,29 @@
|
||||
import io
|
||||
import json
|
||||
from functools import cache
|
||||
from typing import List, Optional, Union
|
||||
from typing import Dict, List, Literal, Optional, Union
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query, Request
|
||||
from fastapi import APIRouter, Depends, File, HTTPException, Query, Request, UploadFile
|
||||
from fastapi.security.http import HTTPAuthorizationCredentials, HTTPBearer
|
||||
from starlette.responses import JSONResponse, StreamingResponse
|
||||
|
||||
from dbgpt.component import SystemApp
|
||||
from dbgpt.core.awel.flow import ResourceMetadata, ViewMetadata
|
||||
from dbgpt.core.awel.flow.flow_factory import FlowCategory
|
||||
from dbgpt.serve.core import Result
|
||||
from dbgpt.serve.core import Result, blocking_func_to_async
|
||||
from dbgpt.util import PaginationResult
|
||||
|
||||
from ..config import APP_NAME, SERVE_SERVICE_COMPONENT_NAME, ServeConfig
|
||||
from ..service.service import Service
|
||||
from .schemas import ServeRequest, ServerResponse
|
||||
from ..service.variables_service import VariablesService
|
||||
from .schemas import (
|
||||
FlowDebugRequest,
|
||||
RefreshNodeRequest,
|
||||
ServeRequest,
|
||||
ServerResponse,
|
||||
VariablesRequest,
|
||||
VariablesResponse,
|
||||
)
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
@@ -23,7 +34,12 @@ global_system_app: Optional[SystemApp] = None
|
||||
|
||||
def get_service() -> Service:
|
||||
"""Get the service instance"""
|
||||
return global_system_app.get_component(SERVE_SERVICE_COMPONENT_NAME, Service)
|
||||
return Service.get_instance(global_system_app)
|
||||
|
||||
|
||||
def get_variable_service() -> VariablesService:
|
||||
"""Get the service instance"""
|
||||
return VariablesService.get_instance(global_system_app)
|
||||
|
||||
|
||||
get_bearer_token = HTTPBearer(auto_error=False)
|
||||
@@ -102,7 +118,9 @@ async def test_auth():
|
||||
|
||||
|
||||
@router.post(
|
||||
"/flows", response_model=Result[None], dependencies=[Depends(check_api_key)]
|
||||
"/flows",
|
||||
response_model=Result[ServerResponse],
|
||||
dependencies=[Depends(check_api_key)],
|
||||
)
|
||||
async def create(
|
||||
request: ServeRequest, service: Service = Depends(get_service)
|
||||
@@ -239,20 +257,236 @@ async def query_page(
|
||||
|
||||
|
||||
@router.get("/nodes", dependencies=[Depends(check_api_key)])
|
||||
async def get_nodes() -> Result[List[Union[ViewMetadata, ResourceMetadata]]]:
|
||||
async def get_nodes(
|
||||
user_name: Optional[str] = Query(default=None, description="user name"),
|
||||
sys_code: Optional[str] = Query(default=None, description="system code"),
|
||||
tags: Optional[str] = Query(default=None, description="tags"),
|
||||
):
|
||||
"""Get the operator or resource nodes
|
||||
|
||||
Args:
|
||||
user_name (Optional[str]): The username
|
||||
sys_code (Optional[str]): The system code
|
||||
tags (Optional[str]): The tags encoded in JSON format
|
||||
|
||||
Returns:
|
||||
Result[List[Union[ViewMetadata, ResourceMetadata]]]:
|
||||
The operator or resource nodes
|
||||
"""
|
||||
from dbgpt.core.awel.flow.base import _OPERATOR_REGISTRY
|
||||
|
||||
return Result.succ(_OPERATOR_REGISTRY.metadata_list())
|
||||
tags_dict: Optional[Dict[str, str]] = None
|
||||
if tags:
|
||||
try:
|
||||
tags_dict = json.loads(tags)
|
||||
except json.JSONDecodeError:
|
||||
return Result.fail("Invalid JSON format for tags")
|
||||
|
||||
metadata_list = await blocking_func_to_async(
|
||||
global_system_app,
|
||||
_OPERATOR_REGISTRY.metadata_list,
|
||||
tags_dict,
|
||||
user_name,
|
||||
sys_code,
|
||||
)
|
||||
return Result.succ(metadata_list)
|
||||
|
||||
|
||||
@router.post("/nodes/refresh", dependencies=[Depends(check_api_key)])
|
||||
async def refresh_nodes(refresh_request: RefreshNodeRequest):
|
||||
"""Refresh the operator or resource nodes
|
||||
|
||||
Returns:
|
||||
Result[None]: The response
|
||||
"""
|
||||
from dbgpt.core.awel.flow.base import _OPERATOR_REGISTRY
|
||||
|
||||
# Make sure the variables provider is initialized
|
||||
_ = get_variable_service().variables_provider
|
||||
|
||||
new_metadata = await _OPERATOR_REGISTRY.refresh(
|
||||
refresh_request.id,
|
||||
refresh_request.flow_type == "operator",
|
||||
refresh_request.refresh,
|
||||
"http",
|
||||
global_system_app,
|
||||
)
|
||||
return Result.succ(new_metadata)
|
||||
|
||||
|
||||
@router.post(
|
||||
"/variables",
|
||||
response_model=Result[VariablesResponse],
|
||||
dependencies=[Depends(check_api_key)],
|
||||
)
|
||||
async def create_variables(
|
||||
variables_request: VariablesRequest,
|
||||
) -> Result[VariablesResponse]:
|
||||
"""Create a new Variables entity
|
||||
|
||||
Args:
|
||||
variables_request (VariablesRequest): The request
|
||||
Returns:
|
||||
VariablesResponse: The response
|
||||
"""
|
||||
res = await blocking_func_to_async(
|
||||
global_system_app, get_variable_service().create, variables_request
|
||||
)
|
||||
return Result.succ(res)
|
||||
|
||||
|
||||
@router.put(
|
||||
"/variables/{v_id}",
|
||||
response_model=Result[VariablesResponse],
|
||||
dependencies=[Depends(check_api_key)],
|
||||
)
|
||||
async def update_variables(
|
||||
v_id: int, variables_request: VariablesRequest
|
||||
) -> Result[VariablesResponse]:
|
||||
"""Update a Variables entity
|
||||
|
||||
Args:
|
||||
v_id (int): The variable id
|
||||
variables_request (VariablesRequest): The request
|
||||
Returns:
|
||||
VariablesResponse: The response
|
||||
"""
|
||||
res = await blocking_func_to_async(
|
||||
global_system_app, get_variable_service().update, v_id, variables_request
|
||||
)
|
||||
return Result.succ(res)
|
||||
|
||||
|
||||
@router.post("/flow/debug", dependencies=[Depends(check_api_key)])
|
||||
async def debug_flow(
|
||||
flow_debug_request: FlowDebugRequest, service: Service = Depends(get_service)
|
||||
):
|
||||
"""Run the flow in debug mode."""
|
||||
# Return the no-incremental stream by default
|
||||
stream_iter = service.debug_flow(flow_debug_request, default_incremental=False)
|
||||
|
||||
headers = {
|
||||
"Content-Type": "text/event-stream",
|
||||
"Cache-Control": "no-cache",
|
||||
"Connection": "keep-alive",
|
||||
"Transfer-Encoding": "chunked",
|
||||
}
|
||||
return StreamingResponse(
|
||||
service._wrapper_chat_stream_flow_str(stream_iter),
|
||||
headers=headers,
|
||||
media_type="text/event-stream",
|
||||
)
|
||||
|
||||
|
||||
@router.get("/flow/export/{uid}", dependencies=[Depends(check_api_key)])
|
||||
async def export_flow(
|
||||
uid: str,
|
||||
export_type: Literal["json", "dbgpts"] = Query(
|
||||
"json", description="export type(json or dbgpts)"
|
||||
),
|
||||
format: Literal["file", "json"] = Query(
|
||||
"file", description="response format(file or json)"
|
||||
),
|
||||
file_name: Optional[str] = Query(default=None, description="file name to export"),
|
||||
user_name: Optional[str] = Query(default=None, description="user name"),
|
||||
sys_code: Optional[str] = Query(default=None, description="system code"),
|
||||
service: Service = Depends(get_service),
|
||||
):
|
||||
"""Export the flow to a file."""
|
||||
flow = service.get({"uid": uid, "user_name": user_name, "sys_code": sys_code})
|
||||
if not flow:
|
||||
raise HTTPException(status_code=404, detail=f"Flow {uid} not found")
|
||||
package_name = flow.name.replace("_", "-")
|
||||
file_name = file_name or package_name
|
||||
if export_type == "json":
|
||||
flow_dict = {"flow": flow.to_dict()}
|
||||
if format == "json":
|
||||
return JSONResponse(content=flow_dict)
|
||||
else:
|
||||
# Return the json file
|
||||
return StreamingResponse(
|
||||
io.BytesIO(json.dumps(flow_dict, ensure_ascii=False).encode("utf-8")),
|
||||
media_type="application/file",
|
||||
headers={
|
||||
"Content-Disposition": f"attachment;filename={file_name}.json"
|
||||
},
|
||||
)
|
||||
|
||||
elif export_type == "dbgpts":
|
||||
from ..service.share_utils import _generate_dbgpts_zip
|
||||
|
||||
if format == "json":
|
||||
raise HTTPException(
|
||||
status_code=400, detail="json response is not supported for dbgpts"
|
||||
)
|
||||
|
||||
zip_buffer = await blocking_func_to_async(
|
||||
global_system_app, _generate_dbgpts_zip, package_name, flow
|
||||
)
|
||||
return StreamingResponse(
|
||||
zip_buffer,
|
||||
media_type="application/x-zip-compressed",
|
||||
headers={"Content-Disposition": f"attachment;filename={file_name}.zip"},
|
||||
)
|
||||
|
||||
|
||||
@router.post(
|
||||
"/flow/import",
|
||||
response_model=Result[ServerResponse],
|
||||
dependencies=[Depends(check_api_key)],
|
||||
)
|
||||
async def import_flow(
|
||||
file: UploadFile = File(...),
|
||||
save_flow: bool = Query(
|
||||
False, description="Whether to save the flow after importing"
|
||||
),
|
||||
service: Service = Depends(get_service),
|
||||
):
|
||||
"""Import the flow from a file."""
|
||||
filename = file.filename
|
||||
file_extension = filename.split(".")[-1].lower()
|
||||
if file_extension == "json":
|
||||
# Handle json file
|
||||
json_content = await file.read()
|
||||
json_dict = json.loads(json_content)
|
||||
if "flow" not in json_dict:
|
||||
raise HTTPException(
|
||||
status_code=400, detail="invalid json file, missing 'flow' key"
|
||||
)
|
||||
flow = ServeRequest.parse_obj(json_dict["flow"])
|
||||
elif file_extension == "zip":
|
||||
from ..service.share_utils import _parse_flow_from_zip_file
|
||||
|
||||
# Handle zip file
|
||||
flow = await _parse_flow_from_zip_file(file, global_system_app)
|
||||
else:
|
||||
raise HTTPException(
|
||||
status_code=400, detail=f"invalid file extension {file_extension}"
|
||||
)
|
||||
if save_flow:
|
||||
return Result.succ(service.create_and_save_dag(flow))
|
||||
else:
|
||||
return Result.succ(flow)
|
||||
|
||||
|
||||
def init_endpoints(system_app: SystemApp) -> None:
|
||||
"""Initialize the endpoints"""
|
||||
from .variables_provider import (
|
||||
BuiltinAllSecretVariablesProvider,
|
||||
BuiltinAllVariablesProvider,
|
||||
BuiltinEmbeddingsVariablesProvider,
|
||||
BuiltinFlowVariablesProvider,
|
||||
BuiltinLLMVariablesProvider,
|
||||
BuiltinNodeVariablesProvider,
|
||||
)
|
||||
|
||||
global global_system_app
|
||||
system_app.register(Service)
|
||||
system_app.register(VariablesService)
|
||||
system_app.register(BuiltinFlowVariablesProvider)
|
||||
system_app.register(BuiltinNodeVariablesProvider)
|
||||
system_app.register(BuiltinAllVariablesProvider)
|
||||
system_app.register(BuiltinAllSecretVariablesProvider)
|
||||
system_app.register(BuiltinLLMVariablesProvider)
|
||||
system_app.register(BuiltinEmbeddingsVariablesProvider)
|
||||
global_system_app = system_app
|
||||
|
@@ -1,7 +1,9 @@
|
||||
from dbgpt._private.pydantic import ConfigDict
|
||||
from typing import Any, Dict, List, Literal, Optional, Union
|
||||
|
||||
# Define your Pydantic schemas here
|
||||
from dbgpt.core.awel.flow.flow_factory import FlowPanel
|
||||
from dbgpt._private.pydantic import BaseModel, ConfigDict, Field
|
||||
from dbgpt.core.awel import CommonLLMHttpRequestBody
|
||||
from dbgpt.core.awel.flow.flow_factory import FlowPanel, VariablesRequest
|
||||
from dbgpt.core.awel.util.parameter_util import RefreshOptionRequest
|
||||
|
||||
from ..config import SERVE_APP_NAME_HUMP
|
||||
|
||||
@@ -14,3 +16,71 @@ class ServerResponse(FlowPanel):
|
||||
# TODO define your own fields here
|
||||
|
||||
model_config = ConfigDict(title=f"ServerResponse for {SERVE_APP_NAME_HUMP}")
|
||||
|
||||
|
||||
class VariablesResponse(VariablesRequest):
|
||||
"""Variable response model."""
|
||||
|
||||
id: int = Field(
|
||||
...,
|
||||
description="The id of the variable",
|
||||
examples=[1],
|
||||
)
|
||||
|
||||
|
||||
class RefreshNodeRequest(BaseModel):
|
||||
"""Flow response model"""
|
||||
|
||||
model_config = ConfigDict(title=f"RefreshNodeRequest")
|
||||
id: str = Field(
|
||||
...,
|
||||
title="The id of the node",
|
||||
description="The id of the node to refresh",
|
||||
examples=["operator_llm_operator___$$___llm___$$___v1"],
|
||||
)
|
||||
flow_type: Literal["operator", "resource"] = Field(
|
||||
"operator",
|
||||
title="The type of the node",
|
||||
description="The type of the node to refresh",
|
||||
examples=["operator", "resource"],
|
||||
)
|
||||
type_name: str = Field(
|
||||
...,
|
||||
title="The type of the node",
|
||||
description="The type of the node to refresh",
|
||||
examples=["LLMOperator"],
|
||||
)
|
||||
type_cls: str = Field(
|
||||
...,
|
||||
title="The class of the node",
|
||||
description="The class of the node to refresh",
|
||||
examples=["dbgpt.core.operator.llm.LLMOperator"],
|
||||
)
|
||||
refresh: List[RefreshOptionRequest] = Field(
|
||||
...,
|
||||
title="The refresh options",
|
||||
description="The refresh options",
|
||||
)
|
||||
|
||||
|
||||
class FlowDebugRequest(BaseModel):
|
||||
"""Flow response model"""
|
||||
|
||||
model_config = ConfigDict(title=f"FlowDebugRequest")
|
||||
flow: ServeRequest = Field(
|
||||
...,
|
||||
title="The flow to debug",
|
||||
description="The flow to debug",
|
||||
)
|
||||
request: Union[CommonLLMHttpRequestBody, Dict[str, Any]] = Field(
|
||||
...,
|
||||
title="The request to debug",
|
||||
description="The request to debug",
|
||||
)
|
||||
variables: Optional[Dict[str, Any]] = Field(
|
||||
None,
|
||||
title="The variables to debug",
|
||||
description="The variables to debug",
|
||||
)
|
||||
user_name: Optional[str] = Field(None, description="User name")
|
||||
sys_code: Optional[str] = Field(None, description="System code")
|
||||
|
260
dbgpt/serve/flow/api/variables_provider.py
Normal file
260
dbgpt/serve/flow/api/variables_provider.py
Normal file
@@ -0,0 +1,260 @@
|
||||
from typing import List, Literal, Optional
|
||||
|
||||
from dbgpt.core.interface.variables import (
|
||||
BUILTIN_VARIABLES_CORE_EMBEDDINGS,
|
||||
BUILTIN_VARIABLES_CORE_FLOW_NODES,
|
||||
BUILTIN_VARIABLES_CORE_FLOWS,
|
||||
BUILTIN_VARIABLES_CORE_LLMS,
|
||||
BUILTIN_VARIABLES_CORE_SECRETS,
|
||||
BUILTIN_VARIABLES_CORE_VARIABLES,
|
||||
BuiltinVariablesProvider,
|
||||
StorageVariables,
|
||||
)
|
||||
|
||||
from ..service.service import Service
|
||||
from .endpoints import get_service, get_variable_service
|
||||
from .schemas import ServerResponse
|
||||
|
||||
|
||||
class BuiltinFlowVariablesProvider(BuiltinVariablesProvider):
|
||||
"""Builtin flow variables provider.
|
||||
|
||||
Provide all flows by variables "${dbgpt.core.flow.flows}"
|
||||
"""
|
||||
|
||||
name = BUILTIN_VARIABLES_CORE_FLOWS
|
||||
|
||||
def get_variables(
|
||||
self,
|
||||
key: str,
|
||||
scope: str = "global",
|
||||
scope_key: Optional[str] = None,
|
||||
sys_code: Optional[str] = None,
|
||||
user_name: Optional[str] = None,
|
||||
) -> List[StorageVariables]:
|
||||
service: Service = get_service()
|
||||
page_result = service.get_list_by_page(
|
||||
{
|
||||
"user_name": user_name,
|
||||
"sys_code": sys_code,
|
||||
},
|
||||
1,
|
||||
1000,
|
||||
)
|
||||
flows: List[ServerResponse] = page_result.items
|
||||
variables = []
|
||||
for flow in flows:
|
||||
variables.append(
|
||||
StorageVariables(
|
||||
key=key,
|
||||
name=flow.name,
|
||||
label=flow.label,
|
||||
value=flow.uid,
|
||||
scope=scope,
|
||||
scope_key=scope_key,
|
||||
sys_code=sys_code,
|
||||
user_name=user_name,
|
||||
)
|
||||
)
|
||||
return variables
|
||||
|
||||
|
||||
class BuiltinNodeVariablesProvider(BuiltinVariablesProvider):
|
||||
"""Builtin node variables provider.
|
||||
|
||||
Provide all nodes by variables "${dbgpt.core.flow.nodes}"
|
||||
"""
|
||||
|
||||
name = BUILTIN_VARIABLES_CORE_FLOW_NODES
|
||||
|
||||
def get_variables(
|
||||
self,
|
||||
key: str,
|
||||
scope: str = "global",
|
||||
scope_key: Optional[str] = None,
|
||||
sys_code: Optional[str] = None,
|
||||
user_name: Optional[str] = None,
|
||||
) -> List[StorageVariables]:
|
||||
"""Get the builtin variables."""
|
||||
from dbgpt.core.awel.flow.base import _OPERATOR_REGISTRY
|
||||
|
||||
metadata_list = _OPERATOR_REGISTRY.metadata_list()
|
||||
variables = []
|
||||
for metadata in metadata_list:
|
||||
variables.append(
|
||||
StorageVariables(
|
||||
key=key,
|
||||
name=metadata["name"],
|
||||
label=metadata["label"],
|
||||
value=metadata["id"],
|
||||
scope=scope,
|
||||
scope_key=scope_key,
|
||||
sys_code=sys_code,
|
||||
user_name=user_name,
|
||||
)
|
||||
)
|
||||
return variables
|
||||
|
||||
|
||||
class BuiltinAllVariablesProvider(BuiltinVariablesProvider):
|
||||
"""Builtin all variables provider.
|
||||
|
||||
Provide all variables by variables "${dbgpt.core.variables}"
|
||||
"""
|
||||
|
||||
name = BUILTIN_VARIABLES_CORE_VARIABLES
|
||||
|
||||
def _get_variables_from_db(
|
||||
self,
|
||||
key: str,
|
||||
scope: str,
|
||||
scope_key: Optional[str],
|
||||
sys_code: Optional[str],
|
||||
user_name: Optional[str],
|
||||
category: Literal["common", "secret"] = "common",
|
||||
) -> List[StorageVariables]:
|
||||
storage_variables = get_variable_service().list_all_variables(category)
|
||||
variables = []
|
||||
for var in storage_variables:
|
||||
variables.append(
|
||||
StorageVariables(
|
||||
key=key,
|
||||
name=var.name,
|
||||
label=var.label,
|
||||
value=var.value,
|
||||
scope=scope,
|
||||
scope_key=scope_key,
|
||||
sys_code=sys_code,
|
||||
user_name=user_name,
|
||||
)
|
||||
)
|
||||
return variables
|
||||
|
||||
def get_variables(
|
||||
self,
|
||||
key: str,
|
||||
scope: str = "global",
|
||||
scope_key: Optional[str] = None,
|
||||
sys_code: Optional[str] = None,
|
||||
user_name: Optional[str] = None,
|
||||
) -> List[StorageVariables]:
|
||||
"""Get the builtin variables.
|
||||
|
||||
TODO: Return all builtin variables
|
||||
"""
|
||||
return self._get_variables_from_db(key, scope, scope_key, sys_code, user_name)
|
||||
|
||||
|
||||
class BuiltinAllSecretVariablesProvider(BuiltinAllVariablesProvider):
|
||||
"""Builtin all secret variables provider.
|
||||
|
||||
Provide all secret variables by variables "${dbgpt.core.secrets}"
|
||||
"""
|
||||
|
||||
name = BUILTIN_VARIABLES_CORE_SECRETS
|
||||
|
||||
def get_variables(
|
||||
self,
|
||||
key: str,
|
||||
scope: str = "global",
|
||||
scope_key: Optional[str] = None,
|
||||
sys_code: Optional[str] = None,
|
||||
user_name: Optional[str] = None,
|
||||
) -> List[StorageVariables]:
|
||||
"""Get the builtin variables."""
|
||||
return self._get_variables_from_db(
|
||||
key, scope, scope_key, sys_code, user_name, "secret"
|
||||
)
|
||||
|
||||
|
||||
class BuiltinLLMVariablesProvider(BuiltinVariablesProvider):
|
||||
"""Builtin LLM variables provider.
|
||||
|
||||
Provide all LLM variables by variables "${dbgpt.core.llmv}"
|
||||
"""
|
||||
|
||||
name = BUILTIN_VARIABLES_CORE_LLMS
|
||||
|
||||
def support_async(self) -> bool:
|
||||
"""Whether the dynamic options support async."""
|
||||
return True
|
||||
|
||||
async def _get_models(
|
||||
self,
|
||||
key: str,
|
||||
scope: str,
|
||||
scope_key: Optional[str],
|
||||
sys_code: Optional[str],
|
||||
user_name: Optional[str],
|
||||
expect_worker_type: str = "llm",
|
||||
) -> List[StorageVariables]:
|
||||
from dbgpt.model.cluster.controller.controller import BaseModelController
|
||||
|
||||
controller = BaseModelController.get_instance(self.system_app)
|
||||
models = await controller.get_all_instances(healthy_only=True)
|
||||
model_dict = {}
|
||||
for model in models:
|
||||
worker_name, worker_type = model.model_name.split("@")
|
||||
if expect_worker_type == worker_type:
|
||||
model_dict[worker_name] = model
|
||||
variables = []
|
||||
for worker_name, model in model_dict.items():
|
||||
variables.append(
|
||||
StorageVariables(
|
||||
key=key,
|
||||
name=worker_name,
|
||||
label=worker_name,
|
||||
value=worker_name,
|
||||
scope=scope,
|
||||
scope_key=scope_key,
|
||||
sys_code=sys_code,
|
||||
user_name=user_name,
|
||||
)
|
||||
)
|
||||
return variables
|
||||
|
||||
async def async_get_variables(
|
||||
self,
|
||||
key: str,
|
||||
scope: str = "global",
|
||||
scope_key: Optional[str] = None,
|
||||
sys_code: Optional[str] = None,
|
||||
user_name: Optional[str] = None,
|
||||
) -> List[StorageVariables]:
|
||||
"""Get the builtin variables."""
|
||||
return await self._get_models(key, scope, scope_key, sys_code, user_name)
|
||||
|
||||
def get_variables(
|
||||
self,
|
||||
key: str,
|
||||
scope: str = "global",
|
||||
scope_key: Optional[str] = None,
|
||||
sys_code: Optional[str] = None,
|
||||
user_name: Optional[str] = None,
|
||||
) -> List[StorageVariables]:
|
||||
"""Get the builtin variables."""
|
||||
raise NotImplementedError(
|
||||
"Not implemented get variables sync, please use async_get_variables"
|
||||
)
|
||||
|
||||
|
||||
class BuiltinEmbeddingsVariablesProvider(BuiltinLLMVariablesProvider):
|
||||
"""Builtin embeddings variables provider.
|
||||
|
||||
Provide all embeddings variables by variables "${dbgpt.core.embeddings}"
|
||||
"""
|
||||
|
||||
name = BUILTIN_VARIABLES_CORE_EMBEDDINGS
|
||||
|
||||
async def async_get_variables(
|
||||
self,
|
||||
key: str,
|
||||
scope: str = "global",
|
||||
scope_key: Optional[str] = None,
|
||||
sys_code: Optional[str] = None,
|
||||
user_name: Optional[str] = None,
|
||||
) -> List[StorageVariables]:
|
||||
"""Get the builtin variables."""
|
||||
return await self._get_models(
|
||||
key, scope, scope_key, sys_code, user_name, "text2vec"
|
||||
)
|
@@ -8,8 +8,10 @@ SERVE_APP_NAME = "dbgpt_serve_flow"
|
||||
SERVE_APP_NAME_HUMP = "dbgpt_serve_Flow"
|
||||
SERVE_CONFIG_KEY_PREFIX = "dbgpt.serve.flow."
|
||||
SERVE_SERVICE_COMPONENT_NAME = f"{SERVE_APP_NAME}_service"
|
||||
SERVE_VARIABLES_SERVICE_COMPONENT_NAME = f"{SERVE_APP_NAME}_variables_service"
|
||||
# Database table name
|
||||
SERVER_APP_TABLE_NAME = "dbgpt_serve_flow"
|
||||
SERVER_APP_VARIABLES_TABLE_NAME = "dbgpt_serve_variables"
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -23,3 +25,6 @@ class ServeConfig(BaseServeConfig):
|
||||
load_dbgpts_interval: int = field(
|
||||
default=5, metadata={"help": "Interval to load dbgpts from installed packages"}
|
||||
)
|
||||
encrypt_key: Optional[str] = field(
|
||||
default=None, metadata={"help": "The key to encrypt the data"}
|
||||
)
|
||||
|
@@ -10,11 +10,17 @@ from sqlalchemy import Column, DateTime, Integer, String, Text, UniqueConstraint
|
||||
|
||||
from dbgpt._private.pydantic import model_to_dict
|
||||
from dbgpt.core.awel.flow.flow_factory import State
|
||||
from dbgpt.core.interface.variables import StorageVariablesProvider
|
||||
from dbgpt.storage.metadata import BaseDao, Model
|
||||
from dbgpt.storage.metadata._base_dao import QUERY_SPEC
|
||||
|
||||
from ..api.schemas import ServeRequest, ServerResponse
|
||||
from ..config import SERVER_APP_TABLE_NAME, ServeConfig
|
||||
from ..api.schemas import (
|
||||
ServeRequest,
|
||||
ServerResponse,
|
||||
VariablesRequest,
|
||||
VariablesResponse,
|
||||
)
|
||||
from ..config import SERVER_APP_TABLE_NAME, SERVER_APP_VARIABLES_TABLE_NAME, ServeConfig
|
||||
|
||||
|
||||
class ServeEntity(Model):
|
||||
@@ -43,6 +49,7 @@ class ServeEntity(Model):
|
||||
editable = Column(
|
||||
Integer, nullable=True, comment="Editable, 0: editable, 1: not editable"
|
||||
)
|
||||
variables = Column(Text, nullable=True, comment="Flow variables, JSON format")
|
||||
user_name = Column(String(128), index=True, nullable=True, comment="User name")
|
||||
sys_code = Column(String(128), index=True, nullable=True, comment="System code")
|
||||
gmt_created = Column(DateTime, default=datetime.now, comment="Record creation time")
|
||||
@@ -74,6 +81,57 @@ class ServeEntity(Model):
|
||||
return editable is None or editable == 0
|
||||
|
||||
|
||||
class VariablesEntity(Model):
|
||||
__tablename__ = SERVER_APP_VARIABLES_TABLE_NAME
|
||||
|
||||
id = Column(Integer, primary_key=True, comment="Auto increment id")
|
||||
key = Column(String(128), index=True, nullable=False, comment="Variable key")
|
||||
name = Column(String(128), index=True, nullable=True, comment="Variable name")
|
||||
label = Column(String(128), nullable=True, comment="Variable label")
|
||||
value = Column(Text, nullable=True, comment="Variable value, JSON format")
|
||||
value_type = Column(
|
||||
String(32),
|
||||
nullable=True,
|
||||
comment="Variable value type(string, int, float, bool)",
|
||||
)
|
||||
category = Column(
|
||||
String(32),
|
||||
default="common",
|
||||
nullable=True,
|
||||
comment="Variable category(common or secret)",
|
||||
)
|
||||
encryption_method = Column(
|
||||
String(32),
|
||||
nullable=True,
|
||||
comment="Variable encryption method(fernet, simple, rsa, aes)",
|
||||
)
|
||||
salt = Column(String(128), nullable=True, comment="Variable salt")
|
||||
scope = Column(
|
||||
String(32),
|
||||
default="global",
|
||||
nullable=True,
|
||||
comment="Variable scope(global,flow,app,agent,datasource,flow_priv,agent_priv, "
|
||||
"etc)",
|
||||
)
|
||||
scope_key = Column(
|
||||
String(256),
|
||||
nullable=True,
|
||||
comment="Variable scope key, default is empty, for scope is 'flow_priv', "
|
||||
"the scope_key is dag id of flow",
|
||||
)
|
||||
enabled = Column(
|
||||
Integer,
|
||||
default=1,
|
||||
nullable=True,
|
||||
comment="Variable enabled, 0: disabled, 1: enabled",
|
||||
)
|
||||
description = Column(Text, nullable=True, comment="Variable description")
|
||||
user_name = Column(String(128), index=True, nullable=True, comment="User name")
|
||||
sys_code = Column(String(128), index=True, nullable=True, comment="System code")
|
||||
gmt_created = Column(DateTime, default=datetime.now, comment="Record creation time")
|
||||
gmt_modified = Column(DateTime, default=datetime.now, comment="Record update time")
|
||||
|
||||
|
||||
class ServeDao(BaseDao[ServeEntity, ServeRequest, ServerResponse]):
|
||||
"""The DAO class for Flow"""
|
||||
|
||||
@@ -98,6 +156,11 @@ class ServeDao(BaseDao[ServeEntity, ServeRequest, ServerResponse]):
|
||||
error_message = request_dict.get("error_message")
|
||||
if error_message:
|
||||
error_message = error_message[:500]
|
||||
|
||||
variables_raw = request_dict.get("variables")
|
||||
variables = (
|
||||
json.dumps(variables_raw, ensure_ascii=False) if variables_raw else None
|
||||
)
|
||||
new_dict = {
|
||||
"uid": request_dict.get("uid"),
|
||||
"dag_id": request_dict.get("dag_id"),
|
||||
@@ -113,6 +176,7 @@ class ServeDao(BaseDao[ServeEntity, ServeRequest, ServerResponse]):
|
||||
"define_type": request_dict.get("define_type"),
|
||||
"editable": ServeEntity.parse_editable(request_dict.get("editable")),
|
||||
"description": request_dict.get("description"),
|
||||
"variables": variables,
|
||||
"user_name": request_dict.get("user_name"),
|
||||
"sys_code": request_dict.get("sys_code"),
|
||||
}
|
||||
@@ -129,6 +193,8 @@ class ServeDao(BaseDao[ServeEntity, ServeRequest, ServerResponse]):
|
||||
REQ: The request
|
||||
"""
|
||||
flow_data = json.loads(entity.flow_data)
|
||||
variables_raw = json.loads(entity.variables) if entity.variables else None
|
||||
variables = ServeRequest.parse_variables(variables_raw)
|
||||
return ServeRequest(
|
||||
uid=entity.uid,
|
||||
dag_id=entity.dag_id,
|
||||
@@ -144,6 +210,7 @@ class ServeDao(BaseDao[ServeEntity, ServeRequest, ServerResponse]):
|
||||
define_type=entity.define_type,
|
||||
editable=ServeEntity.to_bool_editable(entity.editable),
|
||||
description=entity.description,
|
||||
variables=variables,
|
||||
user_name=entity.user_name,
|
||||
sys_code=entity.sys_code,
|
||||
)
|
||||
@@ -160,6 +227,8 @@ class ServeDao(BaseDao[ServeEntity, ServeRequest, ServerResponse]):
|
||||
flow_data = json.loads(entity.flow_data)
|
||||
gmt_created_str = entity.gmt_created.strftime("%Y-%m-%d %H:%M:%S")
|
||||
gmt_modified_str = entity.gmt_modified.strftime("%Y-%m-%d %H:%M:%S")
|
||||
variables_raw = json.loads(entity.variables) if entity.variables else None
|
||||
variables = ServeRequest.parse_variables(variables_raw)
|
||||
return ServerResponse(
|
||||
uid=entity.uid,
|
||||
dag_id=entity.dag_id,
|
||||
@@ -175,6 +244,7 @@ class ServeDao(BaseDao[ServeEntity, ServeRequest, ServerResponse]):
|
||||
version=entity.version,
|
||||
editable=ServeEntity.to_bool_editable(entity.editable),
|
||||
define_type=entity.define_type,
|
||||
variables=variables,
|
||||
user_name=entity.user_name,
|
||||
sys_code=entity.sys_code,
|
||||
gmt_created=gmt_created_str,
|
||||
@@ -215,6 +285,14 @@ class ServeDao(BaseDao[ServeEntity, ServeRequest, ServerResponse]):
|
||||
entry.editable = ServeEntity.parse_editable(update_request.editable)
|
||||
if update_request.define_type:
|
||||
entry.define_type = update_request.define_type
|
||||
|
||||
if update_request.variables:
|
||||
variables_raw = update_request.get_variables_dict()
|
||||
entry.variables = (
|
||||
json.dumps(variables_raw, ensure_ascii=False)
|
||||
if variables_raw
|
||||
else None
|
||||
)
|
||||
if update_request.user_name:
|
||||
entry.user_name = update_request.user_name
|
||||
if update_request.sys_code:
|
||||
@@ -222,3 +300,111 @@ class ServeDao(BaseDao[ServeEntity, ServeRequest, ServerResponse]):
|
||||
session.merge(entry)
|
||||
session.commit()
|
||||
return self.get_one(query_request)
|
||||
|
||||
|
||||
class VariablesDao(BaseDao[VariablesEntity, VariablesRequest, VariablesResponse]):
|
||||
"""The DAO class for Variables"""
|
||||
|
||||
def __init__(self, serve_config: ServeConfig):
|
||||
super().__init__()
|
||||
self._serve_config = serve_config
|
||||
|
||||
def from_request(
|
||||
self, request: Union[VariablesRequest, Dict[str, Any]]
|
||||
) -> VariablesEntity:
|
||||
"""Convert the request to an entity
|
||||
|
||||
Args:
|
||||
request (Union[VariablesRequest, Dict[str, Any]]): The request
|
||||
|
||||
Returns:
|
||||
T: The entity
|
||||
"""
|
||||
request_dict = (
|
||||
model_to_dict(request) if isinstance(request, VariablesRequest) else request
|
||||
)
|
||||
value = StorageVariablesProvider.serialize_value(request_dict.get("value"))
|
||||
enabled = 1 if request_dict.get("enabled", True) else 0
|
||||
new_dict = {
|
||||
"key": request_dict.get("key"),
|
||||
"name": request_dict.get("name"),
|
||||
"label": request_dict.get("label"),
|
||||
"value": value,
|
||||
"value_type": request_dict.get("value_type"),
|
||||
"category": request_dict.get("category"),
|
||||
"encryption_method": request_dict.get("encryption_method"),
|
||||
"salt": request_dict.get("salt"),
|
||||
"scope": request_dict.get("scope"),
|
||||
"scope_key": request_dict.get("scope_key"),
|
||||
"enabled": enabled,
|
||||
"user_name": request_dict.get("user_name"),
|
||||
"sys_code": request_dict.get("sys_code"),
|
||||
"description": request_dict.get("description"),
|
||||
}
|
||||
entity = VariablesEntity(**new_dict)
|
||||
return entity
|
||||
|
||||
def to_request(self, entity: VariablesEntity) -> VariablesRequest:
|
||||
"""Convert the entity to a request
|
||||
|
||||
Args:
|
||||
entity (T): The entity
|
||||
|
||||
Returns:
|
||||
REQ: The request
|
||||
"""
|
||||
value = StorageVariablesProvider.deserialize_value(entity.value)
|
||||
if entity.category == "secret":
|
||||
value = "******"
|
||||
enabled = entity.enabled == 1
|
||||
return VariablesRequest(
|
||||
key=entity.key,
|
||||
name=entity.name,
|
||||
label=entity.label,
|
||||
value=value,
|
||||
value_type=entity.value_type,
|
||||
category=entity.category,
|
||||
encryption_method=entity.encryption_method,
|
||||
salt=entity.salt,
|
||||
scope=entity.scope,
|
||||
scope_key=entity.scope_key,
|
||||
enabled=enabled,
|
||||
user_name=entity.user_name,
|
||||
sys_code=entity.sys_code,
|
||||
description=entity.description,
|
||||
)
|
||||
|
||||
def to_response(self, entity: VariablesEntity) -> VariablesResponse:
|
||||
"""Convert the entity to a response
|
||||
|
||||
Args:
|
||||
entity (T): The entity
|
||||
|
||||
Returns:
|
||||
RES: The response
|
||||
"""
|
||||
value = StorageVariablesProvider.deserialize_value(entity.value)
|
||||
if entity.category == "secret":
|
||||
value = "******"
|
||||
gmt_created_str = entity.gmt_created.strftime("%Y-%m-%d %H:%M:%S")
|
||||
gmt_modified_str = entity.gmt_modified.strftime("%Y-%m-%d %H:%M:%S")
|
||||
enabled = entity.enabled == 1
|
||||
return VariablesResponse(
|
||||
id=entity.id,
|
||||
key=entity.key,
|
||||
name=entity.name,
|
||||
label=entity.label,
|
||||
value=value,
|
||||
value_type=entity.value_type,
|
||||
category=entity.category,
|
||||
encryption_method=entity.encryption_method,
|
||||
salt=entity.salt,
|
||||
scope=entity.scope,
|
||||
scope_key=entity.scope_key,
|
||||
enabled=enabled,
|
||||
user_name=entity.user_name,
|
||||
sys_code=entity.sys_code,
|
||||
gmt_created=gmt_created_str,
|
||||
gmt_modified=gmt_modified_str,
|
||||
description=entity.description,
|
||||
)
|
||||
|
71
dbgpt/serve/flow/models/variables_adapter.py
Normal file
71
dbgpt/serve/flow/models/variables_adapter.py
Normal file
@@ -0,0 +1,71 @@
|
||||
from typing import Type
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from dbgpt.core.interface.storage import StorageItemAdapter
|
||||
from dbgpt.core.interface.variables import StorageVariables, VariablesIdentifier
|
||||
|
||||
from .models import VariablesEntity
|
||||
|
||||
|
||||
class VariablesAdapter(StorageItemAdapter[StorageVariables, VariablesEntity]):
|
||||
"""Variables adapter.
|
||||
|
||||
Convert between storage format and database model.
|
||||
"""
|
||||
|
||||
def to_storage_format(self, item: StorageVariables) -> VariablesEntity:
|
||||
"""Convert to storage format."""
|
||||
return VariablesEntity(
|
||||
key=item.key,
|
||||
name=item.name,
|
||||
label=item.label,
|
||||
value=item.value,
|
||||
value_type=item.value_type,
|
||||
category=item.category,
|
||||
encryption_method=item.encryption_method,
|
||||
salt=item.salt,
|
||||
scope=item.scope,
|
||||
scope_key=item.scope_key,
|
||||
sys_code=item.sys_code,
|
||||
user_name=item.user_name,
|
||||
description=item.description,
|
||||
)
|
||||
|
||||
def from_storage_format(self, model: VariablesEntity) -> StorageVariables:
|
||||
"""Convert from storage format."""
|
||||
return StorageVariables(
|
||||
key=model.key,
|
||||
name=model.name,
|
||||
label=model.label,
|
||||
value=model.value,
|
||||
value_type=model.value_type,
|
||||
category=model.category,
|
||||
encryption_method=model.encryption_method,
|
||||
salt=model.salt,
|
||||
scope=model.scope,
|
||||
scope_key=model.scope_key,
|
||||
sys_code=model.sys_code,
|
||||
user_name=model.user_name,
|
||||
description=model.description,
|
||||
)
|
||||
|
||||
def get_query_for_identifier(
|
||||
self,
|
||||
storage_format: Type[VariablesEntity],
|
||||
resource_id: VariablesIdentifier,
|
||||
**kwargs,
|
||||
):
|
||||
"""Get query for identifier."""
|
||||
session: Session = kwargs.get("session")
|
||||
if session is None:
|
||||
raise Exception("session is None")
|
||||
query_obj = session.query(VariablesEntity)
|
||||
for key, value in resource_id.to_dict().items():
|
||||
if value is None:
|
||||
continue
|
||||
query_obj = query_obj.filter(getattr(VariablesEntity, key) == value)
|
||||
|
||||
# enabled must be True
|
||||
query_obj = query_obj.filter(VariablesEntity.enabled == 1)
|
||||
return query_obj
|
@@ -4,6 +4,7 @@ from typing import List, Optional, Union
|
||||
from sqlalchemy import URL
|
||||
|
||||
from dbgpt.component import SystemApp
|
||||
from dbgpt.core.interface.variables import VariablesProvider
|
||||
from dbgpt.serve.core import BaseServe
|
||||
from dbgpt.storage.metadata import DatabaseManager
|
||||
|
||||
@@ -40,6 +41,8 @@ class Serve(BaseServe):
|
||||
system_app, api_prefix, api_tags, db_url_or_db, try_create_tables
|
||||
)
|
||||
self._db_manager: Optional[DatabaseManager] = None
|
||||
self._variables_provider: Optional[VariablesProvider] = None
|
||||
self._serve_config: Optional[ServeConfig] = None
|
||||
|
||||
def init_app(self, system_app: SystemApp):
|
||||
if self._app_has_initiated:
|
||||
@@ -62,5 +65,37 @@ class Serve(BaseServe):
|
||||
|
||||
def before_start(self):
|
||||
"""Called before the start of the application."""
|
||||
# TODO: Your code here
|
||||
from dbgpt.core.interface.variables import (
|
||||
FernetEncryption,
|
||||
StorageVariablesProvider,
|
||||
)
|
||||
from dbgpt.storage.metadata.db_storage import SQLAlchemyStorage
|
||||
from dbgpt.util.serialization.json_serialization import JsonSerializer
|
||||
|
||||
from .models.models import ServeEntity, VariablesEntity
|
||||
from .models.variables_adapter import VariablesAdapter
|
||||
|
||||
self._db_manager = self.create_or_get_db_manager()
|
||||
self._serve_config = ServeConfig.from_app_config(
|
||||
self._system_app.config, SERVE_CONFIG_KEY_PREFIX
|
||||
)
|
||||
|
||||
self._db_manager = self.create_or_get_db_manager()
|
||||
storage_adapter = VariablesAdapter()
|
||||
serializer = JsonSerializer()
|
||||
storage = SQLAlchemyStorage(
|
||||
self._db_manager,
|
||||
VariablesEntity,
|
||||
storage_adapter,
|
||||
serializer,
|
||||
)
|
||||
self._variables_provider = StorageVariablesProvider(
|
||||
storage=storage,
|
||||
encryption=FernetEncryption(self._serve_config.encrypt_key),
|
||||
system_app=self._system_app,
|
||||
)
|
||||
|
||||
@property
|
||||
def variables_provider(self):
|
||||
"""Get the variables provider of the serve app with db storage"""
|
||||
return self._variables_provider
|
||||
|
@@ -9,7 +9,6 @@ from dbgpt._private.pydantic import model_to_json
|
||||
from dbgpt.agent import AgentDummyTrigger
|
||||
from dbgpt.component import SystemApp
|
||||
from dbgpt.core.awel import DAG, BaseOperator, CommonLLMHttpRequestBody
|
||||
from dbgpt.core.awel.dag.dag_manager import DAGManager
|
||||
from dbgpt.core.awel.flow.flow_factory import (
|
||||
FlowCategory,
|
||||
FlowFactory,
|
||||
@@ -34,7 +33,7 @@ from dbgpt.storage.metadata._base_dao import QUERY_SPEC
|
||||
from dbgpt.util.dbgpts.loader import DBGPTsLoader
|
||||
from dbgpt.util.pagination_utils import PaginationResult
|
||||
|
||||
from ..api.schemas import ServeRequest, ServerResponse
|
||||
from ..api.schemas import FlowDebugRequest, ServeRequest, ServerResponse
|
||||
from ..config import SERVE_CONFIG_KEY_PREFIX, SERVE_SERVICE_COMPONENT_NAME, ServeConfig
|
||||
from ..models.models import ServeDao, ServeEntity
|
||||
|
||||
@@ -147,7 +146,9 @@ class Service(BaseService[ServeEntity, ServeRequest, ServerResponse]):
|
||||
raise ValueError(
|
||||
f"Create DAG {request.name} error, define_type: {request.define_type}, error: {str(e)}"
|
||||
) from e
|
||||
res = self.dao.create(request)
|
||||
self.dao.create(request)
|
||||
# Query from database
|
||||
res = self.get({"uid": request.uid})
|
||||
|
||||
state = request.state
|
||||
try:
|
||||
@@ -574,3 +575,61 @@ class Service(BaseService[ServeEntity, ServeRequest, ServerResponse]):
|
||||
return FlowCategory.CHAT_FLOW
|
||||
except Exception:
|
||||
return FlowCategory.COMMON
|
||||
|
||||
async def debug_flow(
|
||||
self, request: FlowDebugRequest, default_incremental: Optional[bool] = None
|
||||
) -> AsyncIterator[ModelOutput]:
|
||||
"""Debug the flow.
|
||||
|
||||
Args:
|
||||
request (FlowDebugRequest): The request
|
||||
default_incremental (Optional[bool]): The default incremental configuration
|
||||
|
||||
Returns:
|
||||
AsyncIterator[ModelOutput]: The output
|
||||
"""
|
||||
from dbgpt.core.awel.dag.dag_manager import DAGMetadata, _parse_metadata
|
||||
|
||||
dag = self._flow_factory.build(request.flow)
|
||||
leaf_nodes = dag.leaf_nodes
|
||||
if len(leaf_nodes) != 1:
|
||||
raise ValueError("Chat Flow just support one leaf node in dag")
|
||||
task = cast(BaseOperator, leaf_nodes[0])
|
||||
dag_metadata = _parse_metadata(dag)
|
||||
# TODO: Run task with variables
|
||||
variables = request.variables
|
||||
dag_request = request.request
|
||||
|
||||
if isinstance(request.request, CommonLLMHttpRequestBody):
|
||||
incremental = request.request.incremental
|
||||
elif isinstance(request.request, dict):
|
||||
incremental = request.request.get("incremental", False)
|
||||
else:
|
||||
raise ValueError("Invalid request type")
|
||||
|
||||
if default_incremental is not None:
|
||||
incremental = default_incremental
|
||||
|
||||
try:
|
||||
async for output in safe_chat_stream_with_dag_task(
|
||||
task, dag_request, incremental
|
||||
):
|
||||
yield output
|
||||
except HTTPException as e:
|
||||
yield ModelOutput(error_code=1, text=e.detail, incremental=incremental)
|
||||
except Exception as e:
|
||||
yield ModelOutput(error_code=1, text=str(e), incremental=incremental)
|
||||
|
||||
async def _wrapper_chat_stream_flow_str(
|
||||
self, stream_iter: AsyncIterator[ModelOutput]
|
||||
) -> AsyncIterator[str]:
|
||||
|
||||
async for output in stream_iter:
|
||||
text = output.text
|
||||
if text:
|
||||
text = text.replace("\n", "\\n")
|
||||
if output.error_code != 0:
|
||||
yield f"data:[SERVER_ERROR]{text}\n\n"
|
||||
break
|
||||
else:
|
||||
yield f"data:{text}\n\n"
|
||||
|
121
dbgpt/serve/flow/service/share_utils.py
Normal file
121
dbgpt/serve/flow/service/share_utils.py
Normal file
@@ -0,0 +1,121 @@
|
||||
import io
|
||||
import json
|
||||
import os
|
||||
import tempfile
|
||||
import zipfile
|
||||
|
||||
import aiofiles
|
||||
import tomlkit
|
||||
from fastapi import UploadFile
|
||||
|
||||
from dbgpt.component import SystemApp
|
||||
from dbgpt.serve.core import blocking_func_to_async
|
||||
|
||||
from ..api.schemas import ServeRequest
|
||||
|
||||
|
||||
def _generate_dbgpts_zip(package_name: str, flow: ServeRequest) -> io.BytesIO:
|
||||
|
||||
zip_buffer = io.BytesIO()
|
||||
flow_name = flow.name
|
||||
flow_label = flow.label
|
||||
flow_description = flow.description
|
||||
dag_json = json.dumps(flow.flow_data.dict(), indent=4, ensure_ascii=False)
|
||||
with zipfile.ZipFile(zip_buffer, "a", zipfile.ZIP_DEFLATED, False) as zip_file:
|
||||
manifest = f"include dbgpts.toml\ninclude {flow_name}/definition/*.json"
|
||||
readme = f"# {flow_label}\n\n{flow_description}"
|
||||
zip_file.writestr(f"{package_name}/MANIFEST.in", manifest)
|
||||
zip_file.writestr(f"{package_name}/README.md", readme)
|
||||
zip_file.writestr(
|
||||
f"{package_name}/{flow_name}/__init__.py",
|
||||
"",
|
||||
)
|
||||
zip_file.writestr(
|
||||
f"{package_name}/{flow_name}/definition/flow_definition.json",
|
||||
dag_json,
|
||||
)
|
||||
dbgpts_toml = tomlkit.document()
|
||||
# Add flow information
|
||||
dbgpts_flow_toml = tomlkit.document()
|
||||
dbgpts_flow_toml.add("label", "Simple Streaming Chat")
|
||||
name_with_comment = tomlkit.string("awel_flow_simple_streaming_chat")
|
||||
name_with_comment.comment("A unique name for all dbgpts")
|
||||
dbgpts_flow_toml.add("name", name_with_comment)
|
||||
|
||||
dbgpts_flow_toml.add("version", "0.1.0")
|
||||
dbgpts_flow_toml.add(
|
||||
"description",
|
||||
flow_description,
|
||||
)
|
||||
dbgpts_flow_toml.add("authors", [])
|
||||
|
||||
definition_type_with_comment = tomlkit.string("json")
|
||||
definition_type_with_comment.comment("How to define the flow, python or json")
|
||||
dbgpts_flow_toml.add("definition_type", definition_type_with_comment)
|
||||
|
||||
dbgpts_toml.add("flow", dbgpts_flow_toml)
|
||||
|
||||
# Add python and json config
|
||||
python_config = tomlkit.table()
|
||||
dbgpts_toml.add("python_config", python_config)
|
||||
|
||||
json_config = tomlkit.table()
|
||||
json_config.add("file_path", "definition/flow_definition.json")
|
||||
json_config.comment("Json config")
|
||||
|
||||
dbgpts_toml.add("json_config", json_config)
|
||||
|
||||
# Transform to string
|
||||
toml_string = tomlkit.dumps(dbgpts_toml)
|
||||
zip_file.writestr(f"{package_name}/dbgpts.toml", toml_string)
|
||||
|
||||
pyproject_toml = tomlkit.document()
|
||||
|
||||
# Add [tool.poetry] section
|
||||
tool_poetry_toml = tomlkit.table()
|
||||
tool_poetry_toml.add("name", package_name)
|
||||
tool_poetry_toml.add("version", "0.1.0")
|
||||
tool_poetry_toml.add("description", "A dbgpts package")
|
||||
tool_poetry_toml.add("authors", [])
|
||||
tool_poetry_toml.add("readme", "README.md")
|
||||
pyproject_toml["tool"] = tomlkit.table()
|
||||
pyproject_toml["tool"]["poetry"] = tool_poetry_toml
|
||||
|
||||
# Add [tool.poetry.dependencies] section
|
||||
dependencies = tomlkit.table()
|
||||
dependencies.add("python", "^3.10")
|
||||
pyproject_toml["tool"]["poetry"]["dependencies"] = dependencies
|
||||
|
||||
# Add [build-system] section
|
||||
build_system = tomlkit.table()
|
||||
build_system.add("requires", ["poetry-core"])
|
||||
build_system.add("build-backend", "poetry.core.masonry.api")
|
||||
pyproject_toml["build-system"] = build_system
|
||||
|
||||
# Transform to string
|
||||
pyproject_toml_string = tomlkit.dumps(pyproject_toml)
|
||||
zip_file.writestr(f"{package_name}/pyproject.toml", pyproject_toml_string)
|
||||
zip_buffer.seek(0)
|
||||
return zip_buffer
|
||||
|
||||
|
||||
async def _parse_flow_from_zip_file(
|
||||
file: UploadFile, sys_app: SystemApp
|
||||
) -> ServeRequest:
|
||||
from dbgpt.util.dbgpts.loader import _load_flow_package_from_zip_path
|
||||
|
||||
filename = file.filename
|
||||
if not filename.endswith(".zip"):
|
||||
raise ValueError("Uploaded file must be a ZIP file")
|
||||
|
||||
with tempfile.TemporaryDirectory() as temp_dir:
|
||||
zip_path = os.path.join(temp_dir, filename)
|
||||
|
||||
# Save uploaded file to temporary directory
|
||||
async with aiofiles.open(zip_path, "wb") as out_file:
|
||||
while content := await file.read(1024 * 64): # Read in chunks of 64KB
|
||||
await out_file.write(content)
|
||||
flow = await blocking_func_to_async(
|
||||
sys_app, _load_flow_package_from_zip_path, zip_path
|
||||
)
|
||||
return flow
|
152
dbgpt/serve/flow/service/variables_service.py
Normal file
152
dbgpt/serve/flow/service/variables_service.py
Normal file
@@ -0,0 +1,152 @@
|
||||
from typing import List, Optional
|
||||
|
||||
from dbgpt import SystemApp
|
||||
from dbgpt.core.interface.variables import StorageVariables, VariablesProvider
|
||||
from dbgpt.serve.core import BaseService
|
||||
|
||||
from ..api.schemas import VariablesRequest, VariablesResponse
|
||||
from ..config import (
|
||||
SERVE_CONFIG_KEY_PREFIX,
|
||||
SERVE_VARIABLES_SERVICE_COMPONENT_NAME,
|
||||
ServeConfig,
|
||||
)
|
||||
from ..models.models import VariablesDao, VariablesEntity
|
||||
|
||||
|
||||
class VariablesService(
|
||||
BaseService[VariablesEntity, VariablesRequest, VariablesResponse]
|
||||
):
|
||||
"""Variables service"""
|
||||
|
||||
name = SERVE_VARIABLES_SERVICE_COMPONENT_NAME
|
||||
|
||||
def __init__(self, system_app: SystemApp, dao: Optional[VariablesDao] = None):
|
||||
self._system_app = None
|
||||
self._serve_config: ServeConfig = None
|
||||
self._dao: VariablesDao = dao
|
||||
|
||||
super().__init__(system_app)
|
||||
|
||||
def init_app(self, system_app: SystemApp) -> None:
|
||||
"""Initialize the service
|
||||
|
||||
Args:
|
||||
system_app (SystemApp): The system app
|
||||
"""
|
||||
super().init_app(system_app)
|
||||
|
||||
self._serve_config = ServeConfig.from_app_config(
|
||||
system_app.config, SERVE_CONFIG_KEY_PREFIX
|
||||
)
|
||||
self._dao = self._dao or VariablesDao(self._serve_config)
|
||||
self._system_app = system_app
|
||||
|
||||
@property
|
||||
def dao(self) -> VariablesDao:
|
||||
"""Returns the internal DAO."""
|
||||
return self._dao
|
||||
|
||||
@property
|
||||
def variables_provider(self) -> VariablesProvider:
|
||||
"""Returns the internal VariablesProvider.
|
||||
|
||||
Returns:
|
||||
VariablesProvider: The internal VariablesProvider
|
||||
"""
|
||||
variables_provider = VariablesProvider.get_instance(
|
||||
self._system_app, default_component=None
|
||||
)
|
||||
if variables_provider:
|
||||
return variables_provider
|
||||
else:
|
||||
from ..serve import Serve
|
||||
|
||||
variables_provider = Serve.get_instance(self._system_app).variables_provider
|
||||
self._system_app.register_instance(variables_provider)
|
||||
return variables_provider
|
||||
|
||||
@property
|
||||
def config(self) -> ServeConfig:
|
||||
"""Returns the internal ServeConfig."""
|
||||
return self._serve_config
|
||||
|
||||
def create(self, request: VariablesRequest) -> VariablesResponse:
|
||||
"""Create a new entity
|
||||
|
||||
Args:
|
||||
request (VariablesRequest): The request
|
||||
|
||||
Returns:
|
||||
VariablesResponse: The response
|
||||
"""
|
||||
variables = StorageVariables(
|
||||
key=request.key,
|
||||
name=request.name,
|
||||
label=request.label,
|
||||
value=request.value,
|
||||
value_type=request.value_type,
|
||||
category=request.category,
|
||||
scope=request.scope,
|
||||
scope_key=request.scope_key,
|
||||
user_name=request.user_name,
|
||||
sys_code=request.sys_code,
|
||||
enabled=1 if request.enabled else 0,
|
||||
description=request.description,
|
||||
)
|
||||
self.variables_provider.save(variables)
|
||||
query = {
|
||||
"key": request.key,
|
||||
"name": request.name,
|
||||
"scope": request.scope,
|
||||
"scope_key": request.scope_key,
|
||||
"sys_code": request.sys_code,
|
||||
"user_name": request.user_name,
|
||||
"enabled": request.enabled,
|
||||
}
|
||||
return self.dao.get_one(query)
|
||||
|
||||
def update(self, _: int, request: VariablesRequest) -> VariablesResponse:
|
||||
"""Update variables.
|
||||
|
||||
Args:
|
||||
request (VariablesRequest): The request
|
||||
|
||||
Returns:
|
||||
VariablesResponse: The response
|
||||
"""
|
||||
variables = StorageVariables(
|
||||
key=request.key,
|
||||
name=request.name,
|
||||
label=request.label,
|
||||
value=request.value,
|
||||
value_type=request.value_type,
|
||||
category=request.category,
|
||||
scope=request.scope,
|
||||
scope_key=request.scope_key,
|
||||
user_name=request.user_name,
|
||||
sys_code=request.sys_code,
|
||||
enabled=1 if request.enabled else 0,
|
||||
description=request.description,
|
||||
)
|
||||
exist_value = self.variables_provider.get(
|
||||
variables.identifier.str_identifier, None
|
||||
)
|
||||
if exist_value is None:
|
||||
raise ValueError(
|
||||
f"Variable {variables.identifier.str_identifier} not found"
|
||||
)
|
||||
self.variables_provider.save(variables)
|
||||
query = {
|
||||
"key": request.key,
|
||||
"name": request.name,
|
||||
"scope": request.scope,
|
||||
"scope_key": request.scope_key,
|
||||
"sys_code": request.sys_code,
|
||||
"user_name": request.user_name,
|
||||
"enabled": request.enabled,
|
||||
}
|
||||
return self.dao.get_one(query)
|
||||
|
||||
def list_all_variables(self, category: str = "common") -> List[VariablesResponse]:
|
||||
"""List all variables."""
|
||||
return self.dao.get_list({"enabled": True, "category": category})
|
Reference in New Issue
Block a user