feat(core): Add API authentication for serve template (#950)

This commit is contained in:
Fangyin Cheng
2023-12-19 13:41:02 +08:00
committed by GitHub
parent 6739993b94
commit a10d5f57b2
34 changed files with 1293 additions and 377 deletions

View File

@@ -1,5 +1,8 @@
from typing import Optional, List
from fastapi import APIRouter, Depends, Query
from functools import cache
from fastapi import APIRouter, Depends, Query, HTTPException
from fastapi.security.http import HTTPAuthorizationCredentials, HTTPBearer
from dbgpt.component import SystemApp
from dbgpt.serve.core import Result
@@ -20,14 +23,79 @@ def get_service() -> Service:
return global_system_app.get_component(SERVE_SERVICE_COMPONENT_NAME, Service)
get_bearer_token = HTTPBearer(auto_error=False)
@cache
def _parse_api_keys(api_keys: str) -> List[str]:
"""Parse the string api keys to a list
Args:
api_keys (str): The string api keys
Returns:
List[str]: The list of api keys
"""
if not api_keys:
return []
return [key.strip() for key in api_keys.split(",")]
async def check_api_key(
auth: Optional[HTTPAuthorizationCredentials] = Depends(get_bearer_token),
service: Service = Depends(get_service),
) -> Optional[str]:
"""Check the api key
If the api key is not set, allow all.
Your can pass the token in you request header like this:
.. code-block:: python
import requests
client_api_key = "your_api_key"
headers = {"Authorization": "Bearer " + client_api_key }
res = requests.get("http://test/hello", headers=headers)
assert res.status_code == 200
"""
if service.config.api_keys:
api_keys = _parse_api_keys(service.config.api_keys)
if auth is None or (token := auth.credentials) not in api_keys:
raise HTTPException(
status_code=401,
detail={
"error": {
"message": "",
"type": "invalid_request_error",
"param": None,
"code": "invalid_api_key",
}
},
)
return token
else:
# api_keys not set; allow all
return None
@router.get("/health")
async def health():
"""Health check endpoint"""
return {"status": "ok"}
@router.get("/test_auth", dependencies=[Depends(check_api_key)])
async def test_auth():
"""Test auth endpoint"""
return {"status": "ok"}
# TODO: Compatible with old API, will be modified in the future
@router.post("/add", response_model=Result[ServerResponse])
@router.post(
"/add", response_model=Result[ServerResponse], dependencies=[Depends(check_api_key)]
)
async def create(
request: ServeRequest, service: Service = Depends(get_service)
) -> Result[ServerResponse]:
@@ -42,7 +110,11 @@ async def create(
return Result.succ(service.create(request))
@router.post("/update", response_model=Result[ServerResponse])
@router.post(
"/update",
response_model=Result[ServerResponse],
dependencies=[Depends(check_api_key)],
)
async def update(
request: ServeRequest, service: Service = Depends(get_service)
) -> Result[ServerResponse]:
@@ -57,7 +129,9 @@ async def update(
return Result.succ(service.update(request))
@router.post("/delete", response_model=Result[None])
@router.post(
"/delete", response_model=Result[None], dependencies=[Depends(check_api_key)]
)
async def delete(
request: ServeRequest, service: Service = Depends(get_service)
) -> Result[None]:
@@ -72,7 +146,11 @@ async def delete(
return Result.succ(service.delete(request))
@router.post("/list", response_model=Result[List[ServerResponse]])
@router.post(
"/list",
response_model=Result[List[ServerResponse]],
dependencies=[Depends(check_api_key)],
)
async def query(
request: ServeRequest, service: Service = Depends(get_service)
) -> Result[List[ServerResponse]]:
@@ -87,7 +165,11 @@ async def query(
return Result.succ(service.get_list(request))
@router.post("/query_page", response_model=Result[PaginationResult[ServerResponse]])
@router.post(
"/query_page",
response_model=Result[PaginationResult[ServerResponse]],
dependencies=[Depends(check_api_key)],
)
async def query_page(
request: ServeRequest,
page: Optional[int] = Query(default=1, description="current page"),

View File

@@ -1,73 +1,78 @@
# Define your Pydantic schemas here
from typing import Optional
from dbgpt._private.pydantic import BaseModel, Field
from ..config import SERVE_APP_NAME_HUMP
class ServeRequest(BaseModel):
"""Prompt request model"""
chat_scene: Optional[str] = None
"""
The chat scene, e.g. chat_with_db_execute, chat_excel, chat_with_db_qa.
"""
class Config:
title = f"ServeRequest for {SERVE_APP_NAME_HUMP}"
sub_chat_scene: Optional[str] = None
"""
The sub chat scene.
"""
chat_scene: Optional[str] = Field(
None,
description="The chat scene, e.g. chat_with_db_execute, chat_excel, chat_with_db_qa.",
examples=["chat_with_db_execute", "chat_excel", "chat_with_db_qa"],
)
prompt_type: Optional[str] = None
"""
The prompt type, either common or private.
"""
sub_chat_scene: Optional[str] = Field(
None,
description="The sub chat scene.",
examples=["sub_scene_1", "sub_scene_2", "sub_scene_3"],
)
content: Optional[str] = None
"""
The prompt content.
"""
prompt_type: Optional[str] = Field(
None,
description="The prompt type, either common or private.",
examples=["common", "private"],
)
prompt_name: Optional[str] = Field(
None,
description="The prompt name.",
examples=["code_assistant", "joker", "data_analysis_expert"],
)
content: Optional[str] = Field(
None,
description="The prompt content.",
examples=[
"Write a qsort function in python",
"Tell me a joke about AI",
"You are a data analysis expert.",
],
)
user_name: Optional[str] = None
"""
The user name.
"""
user_name: Optional[str] = Field(
None,
description="The user name.",
examples=["zhangsan", "lisi", "wangwu"],
)
sys_code: Optional[str] = None
"""
System code
"""
prompt_name: Optional[str] = None
"""
The prompt name.
"""
sys_code: Optional[str] = Field(
None,
description="The system code.",
examples=["dbgpt", "auth_manager", "data_platform"],
)
class ServerResponse(BaseModel):
class ServerResponse(ServeRequest):
"""Prompt response model"""
id: int = None
"""chat_scene: for example: chat_with_db_execute, chat_excel, chat_with_db_qa"""
class Config:
title = f"ServerResponse for {SERVE_APP_NAME_HUMP}"
chat_scene: str = None
"""sub_chat_scene: sub chat scene"""
sub_chat_scene: str = None
"""prompt_type: common or private"""
prompt_type: str = None
"""content: prompt content"""
content: str = None
"""user_name: user name"""
user_name: str = None
sys_code: Optional[str] = None
"""
System code
"""
"""prompt_name: prompt name"""
prompt_name: str = None
gmt_created: str = None
gmt_modified: str = None
id: Optional[int] = Field(
None,
description="The prompt id.",
examples=[1, 2, 3],
)
gmt_created: Optional[str] = Field(
None,
description="The prompt created time.",
examples=["2021-08-01 12:00:00", "2021-08-01 12:00:01", "2021-08-01 12:00:02"],
)
gmt_modified: Optional[str] = Field(
None,
description="The prompt modified time.",
examples=["2021-08-01 12:00:00", "2021-08-01 12:00:01", "2021-08-01 12:00:02"],
)