feat(core): Add API authentication for serve template (#950)

This commit is contained in:
Fangyin Cheng
2023-12-19 13:41:02 +08:00
committed by GitHub
parent 6739993b94
commit a10d5f57b2
34 changed files with 1293 additions and 377 deletions

View File

@@ -1,5 +1,8 @@
from typing import Optional, List
from fastapi import APIRouter, Depends, Query
from functools import cache
from fastapi import APIRouter, Depends, Query, HTTPException
from fastapi.security.http import HTTPAuthorizationCredentials, HTTPBearer
from dbgpt.component import SystemApp
from dbgpt.serve.core import Result
@@ -20,13 +23,78 @@ def get_service() -> Service:
return global_system_app.get_component(SERVE_SERVICE_COMPONENT_NAME, Service)
get_bearer_token = HTTPBearer(auto_error=False)
@cache
def _parse_api_keys(api_keys: str) -> List[str]:
"""Parse the string api keys to a list
Args:
api_keys (str): The string api keys
Returns:
List[str]: The list of api keys
"""
if not api_keys:
return []
return [key.strip() for key in api_keys.split(",")]
async def check_api_key(
auth: Optional[HTTPAuthorizationCredentials] = Depends(get_bearer_token),
service: Service = Depends(get_service),
) -> Optional[str]:
"""Check the api key
If the api key is not set, allow all.
Your can pass the token in you request header like this:
.. code-block:: python
import requests
client_api_key = "your_api_key"
headers = {"Authorization": "Bearer " + client_api_key }
res = requests.get("http://test/hello", headers=headers)
assert res.status_code == 200
"""
if service.config.api_keys:
api_keys = _parse_api_keys(service.config.api_keys)
if auth is None or (token := auth.credentials) not in api_keys:
raise HTTPException(
status_code=401,
detail={
"error": {
"message": "",
"type": "invalid_request_error",
"param": None,
"code": "invalid_api_key",
}
},
)
return token
else:
# api_keys not set; allow all
return None
@router.get("/health")
async def health():
"""Health check endpoint"""
return {"status": "ok"}
@router.post("/", response_model=Result[ServerResponse])
@router.get("/test_auth", dependencies=[Depends(check_api_key)])
async def test_auth():
"""Test auth endpoint"""
return {"status": "ok"}
@router.post(
"/", response_model=Result[ServerResponse], dependencies=[Depends(check_api_key)]
)
async def create(
request: ServeRequest, service: Service = Depends(get_service)
) -> Result[ServerResponse]:
@@ -41,7 +109,9 @@ async def create(
return Result.succ(service.create(request))
@router.put("/", response_model=Result[ServerResponse])
@router.put(
"/", response_model=Result[ServerResponse], dependencies=[Depends(check_api_key)]
)
async def update(
request: ServeRequest, service: Service = Depends(get_service)
) -> Result[ServerResponse]:
@@ -56,7 +126,11 @@ async def update(
return Result.succ(service.update(request))
@router.post("/query", response_model=Result[ServerResponse])
@router.post(
"/query",
response_model=Result[ServerResponse],
dependencies=[Depends(check_api_key)],
)
async def query(
request: ServeRequest, service: Service = Depends(get_service)
) -> Result[ServerResponse]:
@@ -71,7 +145,11 @@ async def query(
return Result.succ(service.get(request))
@router.post("/query_page", response_model=Result[PaginationResult[ServerResponse]])
@router.post(
"/query_page",
response_model=Result[PaginationResult[ServerResponse]],
dependencies=[Depends(check_api_key)],
)
async def query_page(
request: ServeRequest,
page: Optional[int] = Query(default=1, description="current page"),

View File

@@ -1,5 +1,6 @@
# Define your Pydantic schemas here
from dbgpt._private.pydantic import BaseModel, Field
from ..config import SERVE_APP_NAME_HUMP
class ServeRequest(BaseModel):
@@ -7,8 +8,13 @@ class ServeRequest(BaseModel):
# TODO define your own fields here
class Config:
title = f"ServeRequest for {SERVE_APP_NAME_HUMP}"
class ServerResponse(BaseModel):
"""{__template_app_name__hump__} response model"""
# TODO define your own fields here
class Config:
title = f"ServerResponse for {SERVE_APP_NAME_HUMP}"

View File

@@ -1,4 +1,5 @@
from dataclasses import dataclass
from typing import Optional
from dataclasses import dataclass, field
from dbgpt.serve.core import BaseServeConfig
@@ -17,3 +18,6 @@ class ServeConfig(BaseServeConfig):
"""Parameters for the serve command"""
# TODO: add your own parameters here
api_keys: Optional[str] = field(
default=None, metadata={"help": "API keys for the endpoint, if None, allow all"}
)

View File

@@ -2,7 +2,13 @@ from typing import List, Optional
from dbgpt.component import BaseComponent, SystemApp
from .api.endpoints import router, init_endpoints
from .config import SERVE_APP_NAME, SERVE_APP_NAME_HUMP, APP_NAME
from .config import (
SERVE_APP_NAME,
SERVE_APP_NAME_HUMP,
APP_NAME,
SERVE_CONFIG_KEY_PREFIX,
ServeConfig,
)
class Serve(BaseComponent):

View File

@@ -13,10 +13,10 @@ class Service(BaseService[ServeEntity, ServeRequest, ServerResponse]):
name = SERVE_SERVICE_COMPONENT_NAME
def __init__(self, system_app: SystemApp):
def __init__(self, system_app: SystemApp, dao: Optional[ServeDao] = None):
self._system_app = None
self._serve_config: ServeConfig = None
self._dao: ServeDao = None
self._dao: ServeDao = dao
super().__init__(system_app)
def init_app(self, system_app: SystemApp) -> None:
@@ -28,7 +28,7 @@ class Service(BaseService[ServeEntity, ServeRequest, ServerResponse]):
self._serve_config = ServeConfig.from_app_config(
system_app.config, SERVE_CONFIG_KEY_PREFIX
)
self._dao = ServeDao(self._serve_config)
self._dao = self._dao or ServeDao(self._serve_config)
self._system_app = system_app
@property

View File

@@ -0,0 +1,124 @@
import pytest
from httpx import AsyncClient
from fastapi import FastAPI
from dbgpt.component import SystemApp
from dbgpt.storage.metadata import db
from dbgpt.util import PaginationResult
from ..config import SERVE_CONFIG_KEY_PREFIX
from ..api.endpoints import router, init_endpoints
from ..api.schemas import ServeRequest, ServerResponse
from dbgpt.serve.core.tests.conftest import client, asystem_app
@pytest.fixture(autouse=True)
def setup_and_teardown():
db.init_db("sqlite:///:memory:")
db.create_all()
yield
def client_init_caller(app: FastAPI, system_app: SystemApp):
app.include_router(router)
init_endpoints(system_app)
@pytest.mark.asyncio
@pytest.mark.parametrize(
"client, asystem_app, has_auth",
[
(
{
"app_caller": client_init_caller,
"client_api_key": "test_token1",
},
{
"app_config": {
f"{SERVE_CONFIG_KEY_PREFIX}api_keys": "test_token1,test_token2"
}
},
True,
),
(
{
"app_caller": client_init_caller,
"client_api_key": "error_token",
},
{
"app_config": {
f"{SERVE_CONFIG_KEY_PREFIX}api_keys": "test_token1,test_token2"
}
},
False,
),
],
indirect=["client", "asystem_app"],
)
async def test_api_health(client: AsyncClient, asystem_app, has_auth: bool):
response = await client.get("/test_auth")
if has_auth:
assert response.status_code == 200
assert response.json() == {"status": "ok"}
else:
assert response.status_code == 401
assert response.json() == {
"detail": {
"error": {
"message": "",
"type": "invalid_request_error",
"param": None,
"code": "invalid_api_key",
}
}
}
@pytest.mark.asyncio
@pytest.mark.parametrize(
"client", [{"app_caller": client_init_caller}], indirect=["client"]
)
async def test_api_health(client: AsyncClient):
response = await client.get("/health")
assert response.status_code == 200
assert response.json() == {"status": "ok"}
@pytest.mark.asyncio
@pytest.mark.parametrize(
"client", [{"app_caller": client_init_caller}], indirect=["client"]
)
async def test_api_create(client: AsyncClient):
# TODO: add your test case
pass
@pytest.mark.asyncio
@pytest.mark.parametrize(
"client", [{"app_caller": client_init_caller}], indirect=["client"]
)
async def test_api_update(client: AsyncClient):
# TODO: implement your test case
pass
@pytest.mark.asyncio
@pytest.mark.parametrize(
"client", [{"app_caller": client_init_caller}], indirect=["client"]
)
async def test_api_query(client: AsyncClient):
# TODO: implement your test case
pass
@pytest.mark.asyncio
@pytest.mark.parametrize(
"client", [{"app_caller": client_init_caller}], indirect=["client"]
)
async def test_api_query_by_page(client: AsyncClient):
# TODO: implement your test case
pass
# Add more test cases according to your own logic

View File

@@ -0,0 +1,109 @@
from typing import List
import pytest
from dbgpt.storage.metadata import db
from ..config import ServeConfig
from ..api.schemas import ServeRequest, ServerResponse
from ..models.models import ServeEntity, ServeDao
@pytest.fixture(autouse=True)
def setup_and_teardown():
db.init_db("sqlite:///:memory:")
db.create_all()
yield
@pytest.fixture
def server_config():
# TODO : build your server config
return ServeConfig()
@pytest.fixture
def dao(server_config):
return ServeDao(server_config)
@pytest.fixture
def default_entity_dict():
# TODO: build your default entity dict
return {}
def test_table_exist():
assert ServeEntity.__tablename__ in db.metadata.tables
def test_entity_create(default_entity_dict):
entity: ServeEntity = ServeEntity.create(**default_entity_dict)
# TODO: implement your test case
with db.session() as session:
db_entity: ServeEntity = session.query(ServeEntity).get(entity.id)
assert db_entity.id == entity.id
def test_entity_unique_key(default_entity_dict):
# TODO: implement your test case
pass
def test_entity_get(default_entity_dict):
entity: ServeEntity = ServeEntity.create(**default_entity_dict)
db_entity: ServeEntity = ServeEntity.get(entity.id)
assert db_entity.id == entity.id
# TODO: implement your test case
def test_entity_update(default_entity_dict):
# TODO: implement your test case
pass
def test_entity_delete(default_entity_dict):
# TODO: implement your test case
entity: ServeEntity = ServeEntity.create(**default_entity_dict)
entity.delete()
db_entity: ServeEntity = ServeEntity.get(entity.id)
assert db_entity is None
def test_entity_all():
# TODO: implement your test case
pass
def test_dao_create(dao, default_entity_dict):
# TODO: implement your test case
req = ServeRequest(**default_entity_dict)
res: ServerResponse = dao.create(req)
assert res is not None
def test_dao_get_one(dao, default_entity_dict):
# TODO: implement your test case
req = ServeRequest(**default_entity_dict)
res: ServerResponse = dao.create(req)
def test_get_dao_get_list(dao):
# TODO: implement your test case
pass
def test_dao_update(dao, default_entity_dict):
# TODO: implement your test case
pass
def test_dao_delete(dao, default_entity_dict):
# TODO: implement your test case
pass
def test_dao_get_list_page(dao):
# TODO: implement your test case
pass
# Add more test cases according to your own logic

View File

@@ -0,0 +1,76 @@
from typing import List
import pytest
from dbgpt.component import SystemApp
from dbgpt.storage.metadata import db
from dbgpt.serve.core.tests.conftest import system_app
from ..models.models import ServeEntity
from ..api.schemas import ServeRequest, ServerResponse
from ..service.service import Service
@pytest.fixture(autouse=True)
def setup_and_teardown():
db.init_db("sqlite:///:memory:")
db.create_all()
yield
@pytest.fixture
def service(system_app: SystemApp):
instance = Service(system_app)
instance.init_app(system_app)
return instance
@pytest.fixture
def default_entity_dict():
# TODO: build your default entity dict
return {}
@pytest.mark.parametrize(
"system_app",
[{"app_config": {"DEBUG": True, "dbgpt.serve.test_key": "hello"}}],
indirect=True,
)
def test_config_exists(service: Service):
system_app: SystemApp = service._system_app
assert system_app.config.get("DEBUG") is True
assert system_app.config.get("dbgpt.serve.test_key") == "hello"
assert service.config is not None
def test_service_create(service: Service, default_entity_dict):
# TODO: implement your test case
# eg. entity: ServerResponse = service.create(ServeRequest(**default_entity_dict))
# ...
pass
def test_service_update(service: Service, default_entity_dict):
# TODO: implement your test case
pass
def test_service_get(service: Service, default_entity_dict):
# TODO: implement your test case
pass
def test_service_delete(service: Service, default_entity_dict):
# TODO: implement your test case
pass
def test_service_get_list(service: Service):
# TODO: implement your test case
pass
def test_service_get_list_by_page(service: Service):
# TODO: implement your test case
pass
# Add more test cases according to your own logic

View File

@@ -62,6 +62,7 @@ def replace_template_variables(content: str, app_name: str):
def copy_template_files(src_dir: str, dst_dir: str, app_name: str):
for root, dirs, files in os.walk(src_dir):
dirs[:] = [d for d in dirs if not _should_ignore(d)]
relative_path = os.path.relpath(root, src_dir)
if relative_path == ".":
relative_path = ""
@@ -70,6 +71,8 @@ def copy_template_files(src_dir: str, dst_dir: str, app_name: str):
os.makedirs(target_dir, exist_ok=True)
for file in files:
if _should_ignore(file):
continue
try:
with open(os.path.join(root, file), "r") as f:
content = f.read()
@@ -81,3 +84,9 @@ def copy_template_files(src_dir: str, dst_dir: str, app_name: str):
except Exception as e:
click.echo(f"Error copying file {file} from {src_dir} to {dst_dir}")
raise e
def _should_ignore(file_or_dir: str):
"""Return True if the given file or directory should be ignored.""" ""
ignore_patterns = [".pyc", "__pycache__"]
return any(pattern in file_or_dir for pattern in ignore_patterns)