mirror of
https://github.com/csunny/DB-GPT.git
synced 2025-09-06 03:20:41 +00:00
feat(core): Add API authentication for serve template (#950)
This commit is contained in:
@@ -1,5 +1,8 @@
|
||||
from typing import Optional, List
|
||||
from fastapi import APIRouter, Depends, Query
|
||||
from functools import cache
|
||||
from fastapi import APIRouter, Depends, Query, HTTPException
|
||||
from fastapi.security.http import HTTPAuthorizationCredentials, HTTPBearer
|
||||
|
||||
|
||||
from dbgpt.component import SystemApp
|
||||
from dbgpt.serve.core import Result
|
||||
@@ -20,14 +23,79 @@ def get_service() -> Service:
|
||||
return global_system_app.get_component(SERVE_SERVICE_COMPONENT_NAME, Service)
|
||||
|
||||
|
||||
get_bearer_token = HTTPBearer(auto_error=False)
|
||||
|
||||
|
||||
@cache
|
||||
def _parse_api_keys(api_keys: str) -> List[str]:
|
||||
"""Parse the string api keys to a list
|
||||
|
||||
Args:
|
||||
api_keys (str): The string api keys
|
||||
|
||||
Returns:
|
||||
List[str]: The list of api keys
|
||||
"""
|
||||
if not api_keys:
|
||||
return []
|
||||
return [key.strip() for key in api_keys.split(",")]
|
||||
|
||||
|
||||
async def check_api_key(
|
||||
auth: Optional[HTTPAuthorizationCredentials] = Depends(get_bearer_token),
|
||||
service: Service = Depends(get_service),
|
||||
) -> Optional[str]:
|
||||
"""Check the api key
|
||||
|
||||
If the api key is not set, allow all.
|
||||
|
||||
Your can pass the token in you request header like this:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
import requests
|
||||
client_api_key = "your_api_key"
|
||||
headers = {"Authorization": "Bearer " + client_api_key }
|
||||
res = requests.get("http://test/hello", headers=headers)
|
||||
assert res.status_code == 200
|
||||
|
||||
"""
|
||||
if service.config.api_keys:
|
||||
api_keys = _parse_api_keys(service.config.api_keys)
|
||||
if auth is None or (token := auth.credentials) not in api_keys:
|
||||
raise HTTPException(
|
||||
status_code=401,
|
||||
detail={
|
||||
"error": {
|
||||
"message": "",
|
||||
"type": "invalid_request_error",
|
||||
"param": None,
|
||||
"code": "invalid_api_key",
|
||||
}
|
||||
},
|
||||
)
|
||||
return token
|
||||
else:
|
||||
# api_keys not set; allow all
|
||||
return None
|
||||
|
||||
|
||||
@router.get("/health")
|
||||
async def health():
|
||||
"""Health check endpoint"""
|
||||
return {"status": "ok"}
|
||||
|
||||
|
||||
@router.get("/test_auth", dependencies=[Depends(check_api_key)])
|
||||
async def test_auth():
|
||||
"""Test auth endpoint"""
|
||||
return {"status": "ok"}
|
||||
|
||||
|
||||
# TODO: Compatible with old API, will be modified in the future
|
||||
@router.post("/add", response_model=Result[ServerResponse])
|
||||
@router.post(
|
||||
"/add", response_model=Result[ServerResponse], dependencies=[Depends(check_api_key)]
|
||||
)
|
||||
async def create(
|
||||
request: ServeRequest, service: Service = Depends(get_service)
|
||||
) -> Result[ServerResponse]:
|
||||
@@ -42,7 +110,11 @@ async def create(
|
||||
return Result.succ(service.create(request))
|
||||
|
||||
|
||||
@router.post("/update", response_model=Result[ServerResponse])
|
||||
@router.post(
|
||||
"/update",
|
||||
response_model=Result[ServerResponse],
|
||||
dependencies=[Depends(check_api_key)],
|
||||
)
|
||||
async def update(
|
||||
request: ServeRequest, service: Service = Depends(get_service)
|
||||
) -> Result[ServerResponse]:
|
||||
@@ -57,7 +129,9 @@ async def update(
|
||||
return Result.succ(service.update(request))
|
||||
|
||||
|
||||
@router.post("/delete", response_model=Result[None])
|
||||
@router.post(
|
||||
"/delete", response_model=Result[None], dependencies=[Depends(check_api_key)]
|
||||
)
|
||||
async def delete(
|
||||
request: ServeRequest, service: Service = Depends(get_service)
|
||||
) -> Result[None]:
|
||||
@@ -72,7 +146,11 @@ async def delete(
|
||||
return Result.succ(service.delete(request))
|
||||
|
||||
|
||||
@router.post("/list", response_model=Result[List[ServerResponse]])
|
||||
@router.post(
|
||||
"/list",
|
||||
response_model=Result[List[ServerResponse]],
|
||||
dependencies=[Depends(check_api_key)],
|
||||
)
|
||||
async def query(
|
||||
request: ServeRequest, service: Service = Depends(get_service)
|
||||
) -> Result[List[ServerResponse]]:
|
||||
@@ -87,7 +165,11 @@ async def query(
|
||||
return Result.succ(service.get_list(request))
|
||||
|
||||
|
||||
@router.post("/query_page", response_model=Result[PaginationResult[ServerResponse]])
|
||||
@router.post(
|
||||
"/query_page",
|
||||
response_model=Result[PaginationResult[ServerResponse]],
|
||||
dependencies=[Depends(check_api_key)],
|
||||
)
|
||||
async def query_page(
|
||||
request: ServeRequest,
|
||||
page: Optional[int] = Query(default=1, description="current page"),
|
||||
|
@@ -1,73 +1,78 @@
|
||||
# Define your Pydantic schemas here
|
||||
from typing import Optional
|
||||
from dbgpt._private.pydantic import BaseModel, Field
|
||||
from ..config import SERVE_APP_NAME_HUMP
|
||||
|
||||
|
||||
class ServeRequest(BaseModel):
|
||||
"""Prompt request model"""
|
||||
|
||||
chat_scene: Optional[str] = None
|
||||
"""
|
||||
The chat scene, e.g. chat_with_db_execute, chat_excel, chat_with_db_qa.
|
||||
"""
|
||||
class Config:
|
||||
title = f"ServeRequest for {SERVE_APP_NAME_HUMP}"
|
||||
|
||||
sub_chat_scene: Optional[str] = None
|
||||
"""
|
||||
The sub chat scene.
|
||||
"""
|
||||
chat_scene: Optional[str] = Field(
|
||||
None,
|
||||
description="The chat scene, e.g. chat_with_db_execute, chat_excel, chat_with_db_qa.",
|
||||
examples=["chat_with_db_execute", "chat_excel", "chat_with_db_qa"],
|
||||
)
|
||||
|
||||
prompt_type: Optional[str] = None
|
||||
"""
|
||||
The prompt type, either common or private.
|
||||
"""
|
||||
sub_chat_scene: Optional[str] = Field(
|
||||
None,
|
||||
description="The sub chat scene.",
|
||||
examples=["sub_scene_1", "sub_scene_2", "sub_scene_3"],
|
||||
)
|
||||
|
||||
content: Optional[str] = None
|
||||
"""
|
||||
The prompt content.
|
||||
"""
|
||||
prompt_type: Optional[str] = Field(
|
||||
None,
|
||||
description="The prompt type, either common or private.",
|
||||
examples=["common", "private"],
|
||||
)
|
||||
prompt_name: Optional[str] = Field(
|
||||
None,
|
||||
description="The prompt name.",
|
||||
examples=["code_assistant", "joker", "data_analysis_expert"],
|
||||
)
|
||||
content: Optional[str] = Field(
|
||||
None,
|
||||
description="The prompt content.",
|
||||
examples=[
|
||||
"Write a qsort function in python",
|
||||
"Tell me a joke about AI",
|
||||
"You are a data analysis expert.",
|
||||
],
|
||||
)
|
||||
|
||||
user_name: Optional[str] = None
|
||||
"""
|
||||
The user name.
|
||||
"""
|
||||
user_name: Optional[str] = Field(
|
||||
None,
|
||||
description="The user name.",
|
||||
examples=["zhangsan", "lisi", "wangwu"],
|
||||
)
|
||||
|
||||
sys_code: Optional[str] = None
|
||||
"""
|
||||
System code
|
||||
"""
|
||||
|
||||
prompt_name: Optional[str] = None
|
||||
"""
|
||||
The prompt name.
|
||||
"""
|
||||
sys_code: Optional[str] = Field(
|
||||
None,
|
||||
description="The system code.",
|
||||
examples=["dbgpt", "auth_manager", "data_platform"],
|
||||
)
|
||||
|
||||
|
||||
class ServerResponse(BaseModel):
|
||||
class ServerResponse(ServeRequest):
|
||||
"""Prompt response model"""
|
||||
|
||||
id: int = None
|
||||
"""chat_scene: for example: chat_with_db_execute, chat_excel, chat_with_db_qa"""
|
||||
class Config:
|
||||
title = f"ServerResponse for {SERVE_APP_NAME_HUMP}"
|
||||
|
||||
chat_scene: str = None
|
||||
|
||||
"""sub_chat_scene: sub chat scene"""
|
||||
sub_chat_scene: str = None
|
||||
|
||||
"""prompt_type: common or private"""
|
||||
prompt_type: str = None
|
||||
|
||||
"""content: prompt content"""
|
||||
content: str = None
|
||||
|
||||
"""user_name: user name"""
|
||||
user_name: str = None
|
||||
|
||||
sys_code: Optional[str] = None
|
||||
"""
|
||||
System code
|
||||
"""
|
||||
|
||||
"""prompt_name: prompt name"""
|
||||
prompt_name: str = None
|
||||
gmt_created: str = None
|
||||
gmt_modified: str = None
|
||||
id: Optional[int] = Field(
|
||||
None,
|
||||
description="The prompt id.",
|
||||
examples=[1, 2, 3],
|
||||
)
|
||||
gmt_created: Optional[str] = Field(
|
||||
None,
|
||||
description="The prompt created time.",
|
||||
examples=["2021-08-01 12:00:00", "2021-08-01 12:00:01", "2021-08-01 12:00:02"],
|
||||
)
|
||||
gmt_modified: Optional[str] = Field(
|
||||
None,
|
||||
description="The prompt modified time.",
|
||||
examples=["2021-08-01 12:00:00", "2021-08-01 12:00:01", "2021-08-01 12:00:02"],
|
||||
)
|
||||
|
Reference in New Issue
Block a user