mirror of
https://github.com/csunny/DB-GPT.git
synced 2025-08-03 09:34:04 +00:00
591 lines
18 KiB
Python
591 lines
18 KiB
Python
import io
|
|
import json
|
|
from functools import cache
|
|
from typing import Dict, List, Literal, Optional, Union
|
|
|
|
from fastapi import APIRouter, Depends, File, HTTPException, Query, Request, UploadFile
|
|
from fastapi.security.http import HTTPAuthorizationCredentials, HTTPBearer
|
|
from starlette.responses import JSONResponse, StreamingResponse
|
|
|
|
from dbgpt.component import SystemApp
|
|
from dbgpt.core.awel.flow import ResourceMetadata, ViewMetadata
|
|
from dbgpt.core.awel.flow.flow_factory import FlowCategory
|
|
from dbgpt.serve.core import Result, blocking_func_to_async
|
|
from dbgpt.util import PaginationResult
|
|
|
|
from ..config import APP_NAME, SERVE_SERVICE_COMPONENT_NAME, ServeConfig
|
|
from ..service.service import Service, _parse_flow_template_from_json
|
|
from ..service.variables_service import VariablesService
|
|
from .schemas import (
|
|
FlowDebugRequest,
|
|
RefreshNodeRequest,
|
|
ServeRequest,
|
|
ServerResponse,
|
|
VariablesKeyResponse,
|
|
VariablesRequest,
|
|
VariablesResponse,
|
|
)
|
|
|
|
# Router on which every endpoint declared in this module is registered.
router = APIRouter()
|
|
|
|
# Add your API endpoints here
|
|
|
|
# System app shared by the endpoints below; stays None until init_endpoints()
# assigns the application instance.
global_system_app: Optional[SystemApp] = None
|
|
|
|
|
|
def get_service() -> Service:
    """Return the flow ``Service`` looked up through the global system app."""
    flow_service = Service.get_instance(global_system_app)
    return flow_service
|
|
|
|
|
|
def get_variable_service() -> VariablesService:
    """Return the ``VariablesService`` looked up through the global system app."""
    variables_service = VariablesService.get_instance(global_system_app)
    return variables_service
|
|
|
|
|
|
# Bearer-credentials dependency consumed by check_api_key. It is created with
# auto_error=False; check_api_key itself handles the case where no credentials
# were supplied (auth is None).
get_bearer_token = HTTPBearer(auto_error=False)
|
|
|
|
|
|
@cache
|
|
def _parse_api_keys(api_keys: str) -> List[str]:
|
|
"""Parse the string api keys to a list
|
|
|
|
Args:
|
|
api_keys (str): The string api keys
|
|
|
|
Returns:
|
|
List[str]: The list of api keys
|
|
"""
|
|
if not api_keys:
|
|
return []
|
|
return [key.strip() for key in api_keys.split(",")]
|
|
|
|
|
|
async def check_api_key(
    auth: Optional[HTTPAuthorizationCredentials] = Depends(get_bearer_token),
    request: Request = None,
    service: Service = Depends(get_service),
) -> Optional[str]:
    """Validate the bearer token of the incoming request.

    Requests whose path starts with ``/api/v1`` are never checked. For other
    paths, if no api keys are configured on the service every request is
    allowed; otherwise the bearer token must be one of the configured keys.

    You can pass the token in your request header like this:

    .. code-block:: python

        import requests

        client_api_key = "your_api_key"
        headers = {"Authorization": "Bearer " + client_api_key}
        res = requests.get("http://test/hello", headers=headers)
        assert res.status_code == 200

    Returns:
        Optional[str]: The accepted token, or None when no check was performed

    Raises:
        HTTPException: 401 when api keys are configured and the token is
            missing or not among them
    """
    # Paths under the v1 prefix skip the api key check entirely.
    if request.url.path.startswith("/api/v1"):
        return None

    configured_keys = service.config.api_keys
    if not configured_keys:
        # api_keys not set; allow all
        return None

    allowed_keys = _parse_api_keys(configured_keys)
    token = None if auth is None else auth.credentials
    if auth is None or token not in allowed_keys:
        error_body = {
            "message": "",
            "type": "invalid_request_error",
            "param": None,
            "code": "invalid_api_key",
        }
        raise HTTPException(status_code=401, detail={"error": error_body})
    return token
|
|
|
|
|
|
@router.get("/health")
async def health():
    """Report that the service is reachable (no authentication required)."""
    status_payload = {"status": "ok"}
    return status_payload
|
|
|
|
|
|
@router.get("/test_auth", dependencies=[Depends(check_api_key)])
async def test_auth():
    """Return an ok status once the ``check_api_key`` dependency has passed."""
    status_payload = {"status": "ok"}
    return status_payload
|
|
|
|
|
|
@router.post(
    "/flows",
    response_model=Result[ServerResponse],
    dependencies=[Depends(check_api_key)],
)
async def create(
    request: ServeRequest, service: Service = Depends(get_service)
) -> Result[ServerResponse]:
    """Create a new Flow entity

    The work is delegated to ``service.create_and_save_dag`` through
    ``blocking_func_to_async``.

    Args:
        request (ServeRequest): The request
        service (Service): The service
    Returns:
        ServerResponse: The response
    """
    created_flow = await blocking_func_to_async(
        global_system_app, service.create_and_save_dag, request
    )
    return Result.succ(created_flow)
|
|
|
|
|
|
@router.put(
    "/flows/{uid}",
    response_model=Result[ServerResponse],
    dependencies=[Depends(check_api_key)],
)
async def update(
    uid: str, request: ServeRequest, service: Service = Depends(get_service)
) -> Result[ServerResponse]:
    """Update a Flow entity

    Any exception raised while updating is converted into a failed ``Result``
    carrying the exception text instead of propagating to the framework.

    Args:
        uid (str): The uid
        request (ServeRequest): The request
        service (Service): The service
    Returns:
        ServerResponse: The response
    """
    # NOTE(review): the path parameter ``uid`` is not forwarded; only
    # ``request`` reaches ``service.update_flow``. Presumably the request body
    # carries the uid -- confirm against ServeRequest.
    try:
        return Result.succ(
            await blocking_func_to_async(
                global_system_app, service.update_flow, request
            )
        )
    except Exception as e:
        return Result.failed(msg=str(e))
|
|
|
|
|
|
@router.delete("/flows/{uid}", dependencies=[Depends(check_api_key)])
async def delete(
    uid: str, service: Service = Depends(get_service)
) -> Result[ServerResponse]:
    """Delete a Flow entity

    The endpoint is guarded by ``check_api_key`` like the create and update
    flow endpoints, so a destructive call cannot bypass the api key check that
    the rest of the router enforces.

    Args:
        uid (str): The uid
        service (Service): The service
    Returns:
        Result[ServerResponse]: A success result wrapping the value returned
            by ``service.delete``
    """
    inst = service.delete(uid)
    return Result.succ(inst)
|
|
|
|
|
|
@router.get("/flows/{uid}", dependencies=[Depends(check_api_key)])
async def get_flows(uid: str, service: Service = Depends(get_service)):
    """Get a Flow entity by uid

    The endpoint is guarded by ``check_api_key`` for consistency with the
    other flow endpoints of this router.

    Args:
        uid (str): The uid
        service (Service): The service

    Returns:
        Result[ServerResponse]: The response, wrapping ``flow.model_dump()``

    Raises:
        HTTPException: 404 when no flow with the given uid is found
    """
    flow = service.get({"uid": uid})
    if not flow:
        raise HTTPException(status_code=404, detail=f"Flow {uid} not found")
    return Result.succ(flow.model_dump())
|
|
|
|
|
|
@router.get(
    "/chat/flows",
    response_model=Result[PaginationResult[ServerResponse]],
    dependencies=[Depends(check_api_key)],
)
async def query_chat_flows(
    user_name: Optional[str] = Query(default=None, description="user name"),
    sys_code: Optional[str] = Query(default=None, description="system code"),
    page: int = Query(default=1, description="current page"),
    page_size: int = Query(default=20, description="page size"),
    name: Optional[str] = Query(default=None, description="flow name"),
    uid: Optional[str] = Query(default=None, description="flow uid"),
    service: Service = Depends(get_service),
) -> Result[PaginationResult[ServerResponse]]:
    """Query Flow entities with a chat flow category filter

    Works like the plain flow listing, but additionally passes a
    ``flow_category`` filter of ``CHAT_AGENT`` and ``CHAT_FLOW`` to the
    service.

    Args:
        user_name (Optional[str]): The username
        sys_code (Optional[str]): The system code
        page (int): The page number
        page_size (int): The page size
        name (Optional[str]): The flow name
        uid (Optional[str]): The flow uid
        service (Service): The service
    Returns:
        ServerResponse: The response
    """
    chat_categories = [FlowCategory.CHAT_AGENT, FlowCategory.CHAT_FLOW]
    query_filter = {
        "user_name": user_name,
        "sys_code": sys_code,
        "name": name,
        "uid": uid,
        "flow_category": chat_categories,
    }
    page_result = service.get_list_by_page(query_filter, page, page_size)
    return Result.succ(page_result)
|
|
|
|
|
|
@router.get(
    "/flows",
    response_model=Result[PaginationResult[ServerResponse]],
    dependencies=[Depends(check_api_key)],
)
async def query_page(
    user_name: Optional[str] = Query(default=None, description="user name"),
    sys_code: Optional[str] = Query(default=None, description="system code"),
    page: int = Query(default=1, description="current page"),
    page_size: int = Query(default=20, description="page size"),
    name: Optional[str] = Query(default=None, description="flow name"),
    uid: Optional[str] = Query(default=None, description="flow uid"),
    service: Service = Depends(get_service),
) -> Result[PaginationResult[ServerResponse]]:
    """Query Flow entities page by page

    The four optional query values are handed to the service as a filter
    dictionary together with the requested page and page size.

    Args:
        user_name (Optional[str]): The username
        sys_code (Optional[str]): The system code
        page (int): The page number
        page_size (int): The page size
        name (Optional[str]): The flow name
        uid (Optional[str]): The flow uid
        service (Service): The service
    Returns:
        ServerResponse: The response
    """
    query_filter = {
        "user_name": user_name,
        "sys_code": sys_code,
        "name": name,
        "uid": uid,
    }
    page_result = service.get_list_by_page(query_filter, page, page_size)
    return Result.succ(page_result)
|
|
|
|
|
|
@router.get("/nodes", dependencies=[Depends(check_api_key)])
async def get_nodes(
    user_name: Optional[str] = Query(default=None, description="user name"),
    sys_code: Optional[str] = Query(default=None, description="system code"),
    tags: Optional[str] = Query(default=None, description="tags"),
):
    """Get the operator or resource nodes

    Args:
        user_name (Optional[str]): The username
        sys_code (Optional[str]): The system code
        tags (Optional[str]): The tags encoded in JSON format

    Returns:
        Result[List[Union[ViewMetadata, ResourceMetadata]]]:
            The operator or resource nodes, or a failed result when ``tags``
            is not valid JSON
    """
    from dbgpt.core.awel.flow.base import _OPERATOR_REGISTRY

    tags_dict: Optional[Dict[str, str]] = None
    if tags:
        try:
            tags_dict = json.loads(tags)
        except json.JSONDecodeError:
            # Use the same failure constructor as the rest of this module
            # (``Result.failed``) so the invalid-tags path returns an error
            # result to the client.
            return Result.failed(msg="Invalid JSON format for tags")

    # The registry lookup is synchronous, so it is run through
    # blocking_func_to_async like the other blocking calls in this module.
    metadata_list = await blocking_func_to_async(
        global_system_app,
        _OPERATOR_REGISTRY.metadata_list,
        tags_dict,
        user_name,
        sys_code,
    )
    return Result.succ(metadata_list)
|
|
|
|
|
|
@router.post("/nodes/refresh", dependencies=[Depends(check_api_key)])
async def refresh_nodes(refresh_request: RefreshNodeRequest):
    """Refresh the operator or resource nodes

    Args:
        refresh_request (RefreshNodeRequest): The refresh request; its ``id``,
            ``flow_type`` and ``refresh`` fields are forwarded to the registry

    Returns:
        Result: The response wrapping the refreshed metadata
    """
    from dbgpt.core.awel.flow.base import _OPERATOR_REGISTRY

    # Make sure the variables provider is initialized
    _ = get_variable_service().variables_provider

    is_operator = refresh_request.flow_type == "operator"
    refreshed_metadata = await _OPERATOR_REGISTRY.refresh(
        refresh_request.id,
        is_operator,
        refresh_request.refresh,
        "http",
        global_system_app,
    )
    return Result.succ(refreshed_metadata)
|
|
|
|
|
|
@router.post(
    "/variables",
    response_model=Result[VariablesResponse],
    dependencies=[Depends(check_api_key)],
)
async def create_variables(
    variables_request: VariablesRequest,
) -> Result[VariablesResponse]:
    """Create a new Variables entity

    The work is delegated to the variables service ``create`` method through
    ``blocking_func_to_async``.

    Args:
        variables_request (VariablesRequest): The request
    Returns:
        VariablesResponse: The response
    """
    variables_service = get_variable_service()
    created = await blocking_func_to_async(
        global_system_app, variables_service.create, variables_request
    )
    return Result.succ(created)
|
|
|
|
|
|
@router.put(
    "/variables/{v_id}",
    response_model=Result[VariablesResponse],
    dependencies=[Depends(check_api_key)],
)
async def update_variables(
    v_id: int, variables_request: VariablesRequest
) -> Result[VariablesResponse]:
    """Update a Variables entity

    The work is delegated to the variables service ``update`` method through
    ``blocking_func_to_async``.

    Args:
        v_id (int): The variable id
        variables_request (VariablesRequest): The request
    Returns:
        VariablesResponse: The response
    """
    variables_service = get_variable_service()
    updated = await blocking_func_to_async(
        global_system_app, variables_service.update, v_id, variables_request
    )
    return Result.succ(updated)
|
|
|
|
|
|
@router.get(
    "/variables",
    response_model=Result[PaginationResult[VariablesResponse]],
    dependencies=[Depends(check_api_key)],
)
async def get_variables_by_keys(
    key: str = Query(..., description="variable key"),
    scope: Optional[str] = Query(default=None, description="scope"),
    scope_key: Optional[str] = Query(default=None, description="scope key"),
    user_name: Optional[str] = Query(default=None, description="user name"),
    sys_code: Optional[str] = Query(default=None, description="system code"),
    page: int = Query(default=1, description="current page"),
    page_size: int = Query(default=20, description="page size"),
) -> Result[PaginationResult[VariablesResponse]]:
    """Get the variables by keys

    Unlike most handlers in this module, the service method is awaited
    directly rather than run through ``blocking_func_to_async``.

    Args:
        key (str): The variable key (required)
        scope (Optional[str]): The scope
        scope_key (Optional[str]): The scope key
        user_name (Optional[str]): The username
        sys_code (Optional[str]): The system code
        page (int): The page number
        page_size (int): The page size

    Returns:
        VariablesResponse: The response
    """
    variables_service = get_variable_service()
    page_result = await variables_service.get_list_by_page(
        key, scope, scope_key, user_name, sys_code, page, page_size
    )
    return Result.succ(page_result)
|
|
|
|
|
|
@router.get(
    "/variables/keys",
    response_model=Result[List[VariablesKeyResponse]],
    dependencies=[Depends(check_api_key)],
)
async def get_variables_keys(
    user_name: Optional[str] = Query(default=None, description="user name"),
    sys_code: Optional[str] = Query(default=None, description="system code"),
    category: Optional[str] = Query(default=None, description="category"),
) -> Result[List[VariablesKeyResponse]]:
    """Get the variable keys

    Args:
        user_name (Optional[str]): The username
        sys_code (Optional[str]): The system code
        category (Optional[str]): The category

    Returns:
        VariablesKeyResponse: The response
    """
    variables_service = get_variable_service()
    key_list = await blocking_func_to_async(
        global_system_app, variables_service.list_keys, user_name, sys_code, category
    )
    return Result.succ(key_list)
|
|
|
|
|
|
@router.post("/flow/debug", dependencies=[Depends(check_api_key)])
async def debug_flow(
    flow_debug_request: FlowDebugRequest, service: Service = Depends(get_service)
):
    """Run the flow in debug mode and stream the output as server-sent events.

    Args:
        flow_debug_request (FlowDebugRequest): The debug request
        service (Service): The service

    Returns:
        StreamingResponse: A ``text/event-stream`` response
    """
    sse_headers = {
        "Content-Type": "text/event-stream",
        "Cache-Control": "no-cache",
        "Connection": "keep-alive",
        "Transfer-Encoding": "chunked",
    }
    # Return the no-incremental stream by default
    raw_stream = service.debug_flow(flow_debug_request, default_incremental=False)
    wrapped_stream = service._wrapper_chat_stream_flow_str(raw_stream)
    return StreamingResponse(
        wrapped_stream, headers=sse_headers, media_type="text/event-stream"
    )
|
|
|
|
|
|
@router.get("/flow/export/{uid}", dependencies=[Depends(check_api_key)])
async def export_flow(
    uid: str,
    export_type: Literal["json", "dbgpts"] = Query(
        "json", description="export type(json or dbgpts)"
    ),
    format: Literal["file", "json"] = Query(
        "file", description="response format(file or json)"
    ),
    file_name: Optional[str] = Query(default=None, description="file name to export"),
    user_name: Optional[str] = Query(default=None, description="user name"),
    sys_code: Optional[str] = Query(default=None, description="system code"),
    service: Service = Depends(get_service),
):
    """Export the flow to a file.

    Args:
        uid (str): The flow uid
        export_type: ``json`` for a JSON document, ``dbgpts`` for a zip package
        format: ``file`` for a download, ``json`` for an inline JSON response
            (only available for the ``json`` export type)
        file_name (Optional[str]): Download file name without extension;
            defaults to the flow name with ``_`` replaced by ``-``
        user_name (Optional[str]): The username
        sys_code (Optional[str]): The system code
        service (Service): The service

    Raises:
        HTTPException: 404 when the flow is not found, 400 when a JSON
            response is requested for the ``dbgpts`` export type
    """
    flow = service.get({"uid": uid, "user_name": user_name, "sys_code": sys_code})
    if not flow:
        raise HTTPException(status_code=404, detail=f"Flow {uid} not found")

    package_name = flow.name.replace("_", "-")
    download_name = file_name or package_name

    if export_type == "json":
        payload = {"flow": flow.to_dict()}
        if format == "json":
            return JSONResponse(content=payload)
        # Return the json file
        encoded = json.dumps(payload, ensure_ascii=False).encode("utf-8")
        disposition = f"attachment;filename={download_name}.json"
        return StreamingResponse(
            io.BytesIO(encoded),
            media_type="application/file",
            headers={"Content-Disposition": disposition},
        )

    if export_type == "dbgpts":
        from ..service.share_utils import _generate_dbgpts_zip

        if format == "json":
            raise HTTPException(
                status_code=400, detail="json response is not supported for dbgpts"
            )

        zip_buffer = await blocking_func_to_async(
            global_system_app, _generate_dbgpts_zip, package_name, flow
        )
        disposition = f"attachment;filename={download_name}.zip"
        return StreamingResponse(
            zip_buffer,
            media_type="application/x-zip-compressed",
            headers={"Content-Disposition": disposition},
        )
|
|
|
|
|
|
@router.post(
    "/flow/import",
    response_model=Result[ServerResponse],
    dependencies=[Depends(check_api_key)],
)
async def import_flow(
    file: UploadFile = File(...),
    save_flow: bool = Query(
        False, description="Whether to save the flow after importing"
    ),
    service: Service = Depends(get_service),
):
    """Import the flow from a file.

    The file type is chosen from the text after the last ``.`` of the uploaded
    file name: ``json`` files must contain a JSON object with a ``flow`` key,
    ``zip`` files are parsed by ``_parse_flow_from_zip_file``.

    Args:
        file (UploadFile): The uploaded ``.json`` or ``.zip`` file
        save_flow (bool): Whether to save the flow after importing
        service (Service): The service

    Returns:
        Result: The saved flow when ``save_flow`` is true, otherwise the
            parsed flow

    Raises:
        HTTPException: 400 when the upload has no file name, the extension is
            not supported, the JSON cannot be decoded, or the JSON document is
            not an object containing a ``flow`` key
    """
    filename = file.filename
    if not filename:
        # Without a name there is no extension to dispatch on; reject the
        # upload explicitly instead of failing on ``None.split``.
        raise HTTPException(status_code=400, detail="invalid file, missing file name")

    file_extension = filename.split(".")[-1].lower()
    if file_extension == "json":
        # Handle json file
        json_content = await file.read()
        try:
            json_dict = json.loads(json_content)
        except (json.JSONDecodeError, UnicodeDecodeError) as e:
            # Malformed uploads are a client error, not a server error.
            raise HTTPException(
                status_code=400, detail=f"invalid json file, {e}"
            ) from e
        # A top-level list/string/number cannot carry the 'flow' key.
        if not isinstance(json_dict, dict) or "flow" not in json_dict:
            raise HTTPException(
                status_code=400, detail="invalid json file, missing 'flow' key"
            )
        flow = _parse_flow_template_from_json(json_dict)
    elif file_extension == "zip":
        from ..service.share_utils import _parse_flow_from_zip_file

        # Handle zip file
        flow = await _parse_flow_from_zip_file(file, global_system_app)
    else:
        raise HTTPException(
            status_code=400, detail=f"invalid file extension {file_extension}"
        )

    if save_flow:
        res = await blocking_func_to_async(
            global_system_app, service.create_and_save_dag, flow
        )
        return Result.succ(res)
    return Result.succ(flow)
|
|
|
|
|
|
@router.get(
    "/flow/templates",
    response_model=Result[PaginationResult[ServerResponse]],
    dependencies=[Depends(check_api_key)],
)
async def query_flow_templates(
    user_name: Optional[str] = Query(default=None, description="user name"),
    sys_code: Optional[str] = Query(default=None, description="system code"),
    page: int = Query(default=1, description="current page"),
    page_size: int = Query(default=20, description="page size"),
    service: Service = Depends(get_service),
) -> Result[PaginationResult[ServerResponse]]:
    """Query Flow templates.

    Args:
        user_name (Optional[str]): The username
        sys_code (Optional[str]): The system code
        page (int): The page number
        page_size (int): The page size
        service (Service): The service

    Returns:
        ServerResponse: The response
    """
    template_page = await blocking_func_to_async(
        global_system_app, service.get_flow_templates, user_name, sys_code, page, page_size
    )
    return Result.succ(template_page)
|
|
|
|
|
|
def init_endpoints(system_app: SystemApp) -> None:
    """Register this module's components and remember the system app.

    The two services and the built-in variables providers are registered on
    ``system_app``; afterwards the module-level ``global_system_app`` is set
    so that the endpoint helpers can look the services up.

    Args:
        system_app (SystemApp): The application to register components on
    """
    from .variables_provider import (
        BuiltinAgentsVariablesProvider,
        BuiltinAllSecretVariablesProvider,
        BuiltinAllVariablesProvider,
        BuiltinDatasourceVariablesProvider,
        BuiltinEmbeddingsVariablesProvider,
        BuiltinFlowVariablesProvider,
        BuiltinKnowledgeSpacesVariablesProvider,
        BuiltinLLMVariablesProvider,
        BuiltinNodeVariablesProvider,
    )

    global global_system_app

    # Registration order is kept exactly as before.
    component_classes = (
        Service,
        VariablesService,
        BuiltinFlowVariablesProvider,
        BuiltinNodeVariablesProvider,
        BuiltinAllVariablesProvider,
        BuiltinAllSecretVariablesProvider,
        BuiltinLLMVariablesProvider,
        BuiltinEmbeddingsVariablesProvider,
        BuiltinDatasourceVariablesProvider,
        BuiltinAgentsVariablesProvider,
        BuiltinKnowledgeSpacesVariablesProvider,
    )
    for component_cls in component_classes:
        system_app.register(component_cls)

    global_system_app = system_app
|