Files
DB-GPT/dbgpt/serve/flow/api/endpoints.py
2024-04-20 09:41:16 +08:00

229 lines
6.3 KiB
Python

from functools import cache
from typing import List, Optional, Union
from fastapi import APIRouter, Depends, HTTPException, Query, Request
from fastapi.security.http import HTTPAuthorizationCredentials, HTTPBearer
from dbgpt.component import SystemApp
from dbgpt.core.awel.flow import ResourceMetadata, ViewMetadata
from dbgpt.serve.core import Result
from dbgpt.util import PaginationResult
from ..config import APP_NAME, SERVE_SERVICE_COMPONENT_NAME, ServeConfig
from ..service.service import Service
from .schemas import ServeRequest, ServerResponse
# Router that collects every endpoint declared in this module.
router = APIRouter()
# Add your API endpoints here
# Module-wide SystemApp handle; it stays None until init_endpoints() assigns it.
global_system_app: Optional[SystemApp] = None
def get_service() -> Service:
    """Get the service instance.

    Returns:
        Service: The component registered under ``SERVE_SERVICE_COMPONENT_NAME``
            in the module-level system app.

    Raises:
        RuntimeError: If the module-level system app has not been set yet
            (``init_endpoints`` is what assigns it).
    """
    if global_system_app is None:
        # Fail with an explicit message instead of an opaque AttributeError
        # raised by calling a method on None.
        raise RuntimeError(
            "Flow endpoints are not initialized; call init_endpoints() first"
        )
    return global_system_app.get_component(SERVE_SERVICE_COMPONENT_NAME, Service)
# Bearer-token extractor injected into check_api_key. auto_error=False lets a
# request without credentials reach check_api_key with auth=None, so that
# function decides whether a missing token is acceptable.
get_bearer_token = HTTPBearer(auto_error=False)
@cache
def _parse_api_keys(api_keys: str) -> List[str]:
    """Parse a comma-separated string of api keys into a list.

    Whitespace around each key is stripped and blank entries (for example
    those produced by a trailing or doubled comma) are dropped, so an empty
    string is never treated as a valid key.

    The result is memoized per input string by ``functools.cache``; the
    returned list is shared between calls and must not be mutated.

    Args:
        api_keys (str): The comma-separated api keys

    Returns:
        List[str]: The non-empty, stripped api keys in their original order
    """
    if not api_keys:
        return []
    stripped_keys = (raw_key.strip() for raw_key in api_keys.split(","))
    return [key for key in stripped_keys if key]
async def check_api_key(
    auth: Optional[HTTPAuthorizationCredentials] = Depends(get_bearer_token),
    request: Request = None,
    service: Service = Depends(get_service),
) -> Optional[str]:
    """Validate the bearer token of the incoming request.

    Validation is skipped (``None`` is returned) when the request path starts
    with ``/api/v1`` or when the service has no api keys configured.
    Otherwise the bearer token must be one of the configured keys.

    You can pass the token in your request header like this:

    .. code-block:: python

        import requests

        client_api_key = "your_api_key"
        headers = {"Authorization": "Bearer " + client_api_key}
        res = requests.get("http://test/hello", headers=headers)
        assert res.status_code == 200

    Args:
        auth (Optional[HTTPAuthorizationCredentials]): The parsed bearer
            credentials, or None when the request carries none
        request (Request): The incoming request
        service (Service): The service whose config holds the api keys

    Returns:
        Optional[str]: The accepted token, or None when no check was made

    Raises:
        HTTPException: 401 when api keys are configured and the token is
            missing or not among them
    """
    # Paths under /api/v1 are never checked.
    # NOTE(review): presumably the internal web API prefix - confirm.
    if request.url.path.startswith("/api/v1"):
        return None

    configured_keys = service.config.api_keys
    if not configured_keys:
        # api_keys not set; allow all
        return None

    accepted_keys = _parse_api_keys(configured_keys)
    token = None if auth is None else auth.credentials
    if token is not None and token in accepted_keys:
        return token

    raise HTTPException(
        status_code=401,
        detail={
            "error": {
                "message": "",
                "type": "invalid_request_error",
                "param": None,
                "code": "invalid_api_key",
            }
        },
    )
@router.get("/health")
async def health():
    """Health check endpoint.

    Returns:
        dict: Always ``{"status": "ok"}``
    """
    payload = {"status": "ok"}
    return payload
@router.get("/test_auth", dependencies=[Depends(check_api_key)])
async def test_auth():
    """Test auth endpoint.

    The body is only reached after the ``check_api_key`` dependency has run
    without raising.

    Returns:
        dict: Always ``{"status": "ok"}``
    """
    payload = {"status": "ok"}
    return payload
@router.post(
    "/flows",
    # The handler returns the created flow, so the response model has to
    # describe Result[ServerResponse] (matching the return annotation and the
    # update endpoint) instead of Result[None].
    response_model=Result[ServerResponse],
    dependencies=[Depends(check_api_key)],
)
async def create(
    request: ServeRequest, service: Service = Depends(get_service)
) -> Result[ServerResponse]:
    """Create a new Flow entity.

    Args:
        request (ServeRequest): The request
        service (Service): The service

    Returns:
        Result[ServerResponse]: The value returned by
            ``service.create_and_save_dag`` wrapped in a success result
    """
    return Result.succ(service.create_and_save_dag(request))
@router.put(
    "/flows/{uid}",
    response_model=Result[ServerResponse],
    dependencies=[Depends(check_api_key)],
)
async def update(
    uid: str, request: ServeRequest, service: Service = Depends(get_service)
) -> Result[ServerResponse]:
    """Update a Flow entity.

    Args:
        uid (str): The uid taken from the path
        request (ServeRequest): The request
        service (Service): The service

    Returns:
        Result[ServerResponse]: The value returned by ``service.update_flow``
            wrapped in a success result
    """
    # NOTE(review): the path ``uid`` is not forwarded to the service; the flow
    # to update is presumably identified by the request body - confirm.
    updated_flow = service.update_flow(request)
    return Result.succ(updated_flow)
# The api-key check is applied here as well, so that a destructive route is
# guarded the same way as the create/update/query routes.
@router.delete("/flows/{uid}", dependencies=[Depends(check_api_key)])
async def delete(
    uid: str, service: Service = Depends(get_service)
) -> Result[ServerResponse]:
    """Delete a Flow entity.

    Args:
        uid (str): The uid
        service (Service): The service

    Returns:
        Result[ServerResponse]: The value returned by ``service.delete``
            wrapped in a success result
    """
    inst = service.delete(uid)
    return Result.succ(inst)
# The api-key check is applied here as well, so that reading a single flow is
# guarded the same way as listing flows.
@router.get("/flows/{uid}", dependencies=[Depends(check_api_key)])
async def get_flows(
    uid: str, service: Service = Depends(get_service)
) -> Result[ServerResponse]:
    """Get a Flow entity by uid.

    Args:
        uid (str): The uid
        service (Service): The service

    Returns:
        Result[ServerResponse]: The flow wrapped in a success result

    Raises:
        HTTPException: 404 when the service returns no flow for ``uid``
    """
    flow = service.get({"uid": uid})
    if not flow:
        raise HTTPException(status_code=404, detail=f"Flow {uid} not found")
    return Result.succ(flow)
@router.get(
    "/flows",
    response_model=Result[PaginationResult[ServerResponse]],
    dependencies=[Depends(check_api_key)],
)
async def query_page(
    user_name: Optional[str] = Query(default=None, description="user name"),
    sys_code: Optional[str] = Query(default=None, description="system code"),
    page: int = Query(default=1, description="current page"),
    page_size: int = Query(default=20, description="page size"),
    name: Optional[str] = Query(default=None, description="flow name"),
    uid: Optional[str] = Query(default=None, description="flow uid"),
    service: Service = Depends(get_service),
) -> Result[PaginationResult[ServerResponse]]:
    """Query Flow entities page by page.

    Args:
        user_name (Optional[str]): The username
        sys_code (Optional[str]): The system code
        page (int): The page number
        page_size (int): The page size
        name (Optional[str]): The flow name
        uid (Optional[str]): The flow uid
        service (Service): The service

    Returns:
        Result[PaginationResult[ServerResponse]]: The value returned by
            ``service.get_list_by_page`` wrapped in a success result
    """
    # All four filters are always passed on, including those left as None.
    filters = {
        "user_name": user_name,
        "sys_code": sys_code,
        "name": name,
        "uid": uid,
    }
    flows_page = service.get_list_by_page(filters, page, page_size)
    return Result.succ(flows_page)
@router.get("/nodes", dependencies=[Depends(check_api_key)])
async def get_nodes() -> Result[List[Union[ViewMetadata, ResourceMetadata]]]:
    """Get the operator or resource nodes.

    Returns:
        Result[List[Union[ViewMetadata, ResourceMetadata]]]:
            The metadata list of the operator registry wrapped in a success
            result
    """
    # Imported at call time rather than at module import.
    # NOTE(review): presumably to avoid an import cycle - confirm.
    from dbgpt.core.awel.flow.base import _OPERATOR_REGISTRY

    node_metadata = _OPERATOR_REGISTRY.metadata_list()
    return Result.succ(node_metadata)
def init_endpoints(system_app: SystemApp) -> None:
    """Initialize the endpoints.

    Registers ``Service`` on the given system app and then stores the app in
    the module-level ``global_system_app``.

    Args:
        system_app (SystemApp): The system app to register the service on
    """
    global global_system_app

    # Register first: the global is only assigned once registration succeeded.
    system_app.register(Service)
    global_system_app = system_app