mirror of
https://github.com/csunny/DB-GPT.git
synced 2025-09-01 17:16:51 +00:00
feat(core): Add API authentication for serve template (#950)
This commit is contained in:
@@ -1,5 +1,8 @@
|
||||
from typing import Optional, List
|
||||
from fastapi import APIRouter, Depends, Query
|
||||
from functools import cache
|
||||
from fastapi import APIRouter, Depends, Query, HTTPException
|
||||
from fastapi.security.http import HTTPAuthorizationCredentials, HTTPBearer
|
||||
|
||||
|
||||
from dbgpt.component import SystemApp
|
||||
from dbgpt.serve.core import Result
|
||||
@@ -20,13 +23,78 @@ def get_service() -> Service:
|
||||
return global_system_app.get_component(SERVE_SERVICE_COMPONENT_NAME, Service)
|
||||
|
||||
|
||||
get_bearer_token = HTTPBearer(auto_error=False)
|
||||
|
||||
|
||||
@cache
|
||||
def _parse_api_keys(api_keys: str) -> List[str]:
|
||||
"""Parse the string api keys to a list
|
||||
|
||||
Args:
|
||||
api_keys (str): The string api keys
|
||||
|
||||
Returns:
|
||||
List[str]: The list of api keys
|
||||
"""
|
||||
if not api_keys:
|
||||
return []
|
||||
return [key.strip() for key in api_keys.split(",")]
|
||||
|
||||
|
||||
async def check_api_key(
|
||||
auth: Optional[HTTPAuthorizationCredentials] = Depends(get_bearer_token),
|
||||
service: Service = Depends(get_service),
|
||||
) -> Optional[str]:
|
||||
"""Check the api key
|
||||
|
||||
If the api key is not set, allow all.
|
||||
|
||||
Your can pass the token in you request header like this:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
import requests
|
||||
client_api_key = "your_api_key"
|
||||
headers = {"Authorization": "Bearer " + client_api_key }
|
||||
res = requests.get("http://test/hello", headers=headers)
|
||||
assert res.status_code == 200
|
||||
|
||||
"""
|
||||
if service.config.api_keys:
|
||||
api_keys = _parse_api_keys(service.config.api_keys)
|
||||
if auth is None or (token := auth.credentials) not in api_keys:
|
||||
raise HTTPException(
|
||||
status_code=401,
|
||||
detail={
|
||||
"error": {
|
||||
"message": "",
|
||||
"type": "invalid_request_error",
|
||||
"param": None,
|
||||
"code": "invalid_api_key",
|
||||
}
|
||||
},
|
||||
)
|
||||
return token
|
||||
else:
|
||||
# api_keys not set; allow all
|
||||
return None
|
||||
|
||||
|
||||
@router.get("/health")
|
||||
async def health():
|
||||
"""Health check endpoint"""
|
||||
return {"status": "ok"}
|
||||
|
||||
|
||||
@router.post("/", response_model=Result[ServerResponse])
|
||||
@router.get("/test_auth", dependencies=[Depends(check_api_key)])
|
||||
async def test_auth():
|
||||
"""Test auth endpoint"""
|
||||
return {"status": "ok"}
|
||||
|
||||
|
||||
@router.post(
|
||||
"/", response_model=Result[ServerResponse], dependencies=[Depends(check_api_key)]
|
||||
)
|
||||
async def create(
|
||||
request: ServeRequest, service: Service = Depends(get_service)
|
||||
) -> Result[ServerResponse]:
|
||||
@@ -41,7 +109,9 @@ async def create(
|
||||
return Result.succ(service.create(request))
|
||||
|
||||
|
||||
@router.put("/", response_model=Result[ServerResponse])
|
||||
@router.put(
|
||||
"/", response_model=Result[ServerResponse], dependencies=[Depends(check_api_key)]
|
||||
)
|
||||
async def update(
|
||||
request: ServeRequest, service: Service = Depends(get_service)
|
||||
) -> Result[ServerResponse]:
|
||||
@@ -56,7 +126,11 @@ async def update(
|
||||
return Result.succ(service.update(request))
|
||||
|
||||
|
||||
@router.post("/query", response_model=Result[ServerResponse])
|
||||
@router.post(
|
||||
"/query",
|
||||
response_model=Result[ServerResponse],
|
||||
dependencies=[Depends(check_api_key)],
|
||||
)
|
||||
async def query(
|
||||
request: ServeRequest, service: Service = Depends(get_service)
|
||||
) -> Result[ServerResponse]:
|
||||
@@ -71,7 +145,11 @@ async def query(
|
||||
return Result.succ(service.get(request))
|
||||
|
||||
|
||||
@router.post("/query_page", response_model=Result[PaginationResult[ServerResponse]])
|
||||
@router.post(
|
||||
"/query_page",
|
||||
response_model=Result[PaginationResult[ServerResponse]],
|
||||
dependencies=[Depends(check_api_key)],
|
||||
)
|
||||
async def query_page(
|
||||
request: ServeRequest,
|
||||
page: Optional[int] = Query(default=1, description="current page"),
|
||||
|
@@ -1,5 +1,6 @@
|
||||
# Define your Pydantic schemas here
|
||||
from dbgpt._private.pydantic import BaseModel, Field
|
||||
from ..config import SERVE_APP_NAME_HUMP
|
||||
|
||||
|
||||
class ServeRequest(BaseModel):
|
||||
@@ -7,8 +8,13 @@ class ServeRequest(BaseModel):
|
||||
|
||||
# TODO define your own fields here
|
||||
|
||||
class Config:
|
||||
title = f"ServeRequest for {SERVE_APP_NAME_HUMP}"
|
||||
|
||||
|
||||
class ServerResponse(BaseModel):
|
||||
"""{__template_app_name__hump__} response model"""
|
||||
|
||||
# TODO define your own fields here
|
||||
class Config:
|
||||
title = f"ServerResponse for {SERVE_APP_NAME_HUMP}"
|
||||
|
@@ -1,4 +1,5 @@
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
from dbgpt.serve.core import BaseServeConfig
|
||||
|
||||
@@ -17,3 +18,6 @@ class ServeConfig(BaseServeConfig):
|
||||
"""Parameters for the serve command"""
|
||||
|
||||
# TODO: add your own parameters here
|
||||
api_keys: Optional[str] = field(
|
||||
default=None, metadata={"help": "API keys for the endpoint, if None, allow all"}
|
||||
)
|
||||
|
@@ -2,7 +2,13 @@ from typing import List, Optional
|
||||
from dbgpt.component import BaseComponent, SystemApp
|
||||
|
||||
from .api.endpoints import router, init_endpoints
|
||||
from .config import SERVE_APP_NAME, SERVE_APP_NAME_HUMP, APP_NAME
|
||||
from .config import (
|
||||
SERVE_APP_NAME,
|
||||
SERVE_APP_NAME_HUMP,
|
||||
APP_NAME,
|
||||
SERVE_CONFIG_KEY_PREFIX,
|
||||
ServeConfig,
|
||||
)
|
||||
|
||||
|
||||
class Serve(BaseComponent):
|
||||
|
@@ -13,10 +13,10 @@ class Service(BaseService[ServeEntity, ServeRequest, ServerResponse]):
|
||||
|
||||
name = SERVE_SERVICE_COMPONENT_NAME
|
||||
|
||||
def __init__(self, system_app: SystemApp):
|
||||
def __init__(self, system_app: SystemApp, dao: Optional[ServeDao] = None):
|
||||
self._system_app = None
|
||||
self._serve_config: ServeConfig = None
|
||||
self._dao: ServeDao = None
|
||||
self._dao: ServeDao = dao
|
||||
super().__init__(system_app)
|
||||
|
||||
def init_app(self, system_app: SystemApp) -> None:
|
||||
@@ -28,7 +28,7 @@ class Service(BaseService[ServeEntity, ServeRequest, ServerResponse]):
|
||||
self._serve_config = ServeConfig.from_app_config(
|
||||
system_app.config, SERVE_CONFIG_KEY_PREFIX
|
||||
)
|
||||
self._dao = ServeDao(self._serve_config)
|
||||
self._dao = self._dao or ServeDao(self._serve_config)
|
||||
self._system_app = system_app
|
||||
|
||||
@property
|
||||
|
@@ -0,0 +1,124 @@
|
||||
import pytest
|
||||
from httpx import AsyncClient
|
||||
|
||||
from fastapi import FastAPI
|
||||
from dbgpt.component import SystemApp
|
||||
from dbgpt.storage.metadata import db
|
||||
from dbgpt.util import PaginationResult
|
||||
from ..config import SERVE_CONFIG_KEY_PREFIX
|
||||
from ..api.endpoints import router, init_endpoints
|
||||
from ..api.schemas import ServeRequest, ServerResponse
|
||||
|
||||
from dbgpt.serve.core.tests.conftest import client, asystem_app
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def setup_and_teardown():
|
||||
db.init_db("sqlite:///:memory:")
|
||||
db.create_all()
|
||||
|
||||
yield
|
||||
|
||||
|
||||
def client_init_caller(app: FastAPI, system_app: SystemApp):
|
||||
app.include_router(router)
|
||||
init_endpoints(system_app)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize(
|
||||
"client, asystem_app, has_auth",
|
||||
[
|
||||
(
|
||||
{
|
||||
"app_caller": client_init_caller,
|
||||
"client_api_key": "test_token1",
|
||||
},
|
||||
{
|
||||
"app_config": {
|
||||
f"{SERVE_CONFIG_KEY_PREFIX}api_keys": "test_token1,test_token2"
|
||||
}
|
||||
},
|
||||
True,
|
||||
),
|
||||
(
|
||||
{
|
||||
"app_caller": client_init_caller,
|
||||
"client_api_key": "error_token",
|
||||
},
|
||||
{
|
||||
"app_config": {
|
||||
f"{SERVE_CONFIG_KEY_PREFIX}api_keys": "test_token1,test_token2"
|
||||
}
|
||||
},
|
||||
False,
|
||||
),
|
||||
],
|
||||
indirect=["client", "asystem_app"],
|
||||
)
|
||||
async def test_api_health(client: AsyncClient, asystem_app, has_auth: bool):
|
||||
response = await client.get("/test_auth")
|
||||
if has_auth:
|
||||
assert response.status_code == 200
|
||||
assert response.json() == {"status": "ok"}
|
||||
else:
|
||||
assert response.status_code == 401
|
||||
assert response.json() == {
|
||||
"detail": {
|
||||
"error": {
|
||||
"message": "",
|
||||
"type": "invalid_request_error",
|
||||
"param": None,
|
||||
"code": "invalid_api_key",
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize(
|
||||
"client", [{"app_caller": client_init_caller}], indirect=["client"]
|
||||
)
|
||||
async def test_api_health(client: AsyncClient):
|
||||
response = await client.get("/health")
|
||||
assert response.status_code == 200
|
||||
assert response.json() == {"status": "ok"}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize(
|
||||
"client", [{"app_caller": client_init_caller}], indirect=["client"]
|
||||
)
|
||||
async def test_api_create(client: AsyncClient):
|
||||
# TODO: add your test case
|
||||
pass
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize(
|
||||
"client", [{"app_caller": client_init_caller}], indirect=["client"]
|
||||
)
|
||||
async def test_api_update(client: AsyncClient):
|
||||
# TODO: implement your test case
|
||||
pass
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize(
|
||||
"client", [{"app_caller": client_init_caller}], indirect=["client"]
|
||||
)
|
||||
async def test_api_query(client: AsyncClient):
|
||||
# TODO: implement your test case
|
||||
pass
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize(
|
||||
"client", [{"app_caller": client_init_caller}], indirect=["client"]
|
||||
)
|
||||
async def test_api_query_by_page(client: AsyncClient):
|
||||
# TODO: implement your test case
|
||||
pass
|
||||
|
||||
|
||||
# Add more test cases according to your own logic
|
@@ -0,0 +1,109 @@
|
||||
from typing import List
|
||||
import pytest
|
||||
from dbgpt.storage.metadata import db
|
||||
from ..config import ServeConfig
|
||||
from ..api.schemas import ServeRequest, ServerResponse
|
||||
from ..models.models import ServeEntity, ServeDao
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def setup_and_teardown():
|
||||
db.init_db("sqlite:///:memory:")
|
||||
db.create_all()
|
||||
|
||||
yield
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def server_config():
|
||||
# TODO : build your server config
|
||||
return ServeConfig()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def dao(server_config):
|
||||
return ServeDao(server_config)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def default_entity_dict():
|
||||
# TODO: build your default entity dict
|
||||
return {}
|
||||
|
||||
|
||||
def test_table_exist():
|
||||
assert ServeEntity.__tablename__ in db.metadata.tables
|
||||
|
||||
|
||||
def test_entity_create(default_entity_dict):
|
||||
entity: ServeEntity = ServeEntity.create(**default_entity_dict)
|
||||
# TODO: implement your test case
|
||||
with db.session() as session:
|
||||
db_entity: ServeEntity = session.query(ServeEntity).get(entity.id)
|
||||
assert db_entity.id == entity.id
|
||||
|
||||
|
||||
def test_entity_unique_key(default_entity_dict):
|
||||
# TODO: implement your test case
|
||||
pass
|
||||
|
||||
|
||||
def test_entity_get(default_entity_dict):
|
||||
entity: ServeEntity = ServeEntity.create(**default_entity_dict)
|
||||
db_entity: ServeEntity = ServeEntity.get(entity.id)
|
||||
assert db_entity.id == entity.id
|
||||
# TODO: implement your test case
|
||||
|
||||
|
||||
def test_entity_update(default_entity_dict):
|
||||
# TODO: implement your test case
|
||||
pass
|
||||
|
||||
|
||||
def test_entity_delete(default_entity_dict):
|
||||
# TODO: implement your test case
|
||||
entity: ServeEntity = ServeEntity.create(**default_entity_dict)
|
||||
entity.delete()
|
||||
db_entity: ServeEntity = ServeEntity.get(entity.id)
|
||||
assert db_entity is None
|
||||
|
||||
|
||||
def test_entity_all():
|
||||
# TODO: implement your test case
|
||||
pass
|
||||
|
||||
|
||||
def test_dao_create(dao, default_entity_dict):
|
||||
# TODO: implement your test case
|
||||
req = ServeRequest(**default_entity_dict)
|
||||
res: ServerResponse = dao.create(req)
|
||||
assert res is not None
|
||||
|
||||
|
||||
def test_dao_get_one(dao, default_entity_dict):
|
||||
# TODO: implement your test case
|
||||
req = ServeRequest(**default_entity_dict)
|
||||
res: ServerResponse = dao.create(req)
|
||||
|
||||
|
||||
def test_get_dao_get_list(dao):
|
||||
# TODO: implement your test case
|
||||
pass
|
||||
|
||||
|
||||
def test_dao_update(dao, default_entity_dict):
|
||||
# TODO: implement your test case
|
||||
pass
|
||||
|
||||
|
||||
def test_dao_delete(dao, default_entity_dict):
|
||||
# TODO: implement your test case
|
||||
pass
|
||||
|
||||
|
||||
def test_dao_get_list_page(dao):
|
||||
# TODO: implement your test case
|
||||
pass
|
||||
|
||||
|
||||
# Add more test cases according to your own logic
|
@@ -0,0 +1,76 @@
|
||||
from typing import List
|
||||
import pytest
|
||||
from dbgpt.component import SystemApp
|
||||
from dbgpt.storage.metadata import db
|
||||
from dbgpt.serve.core.tests.conftest import system_app
|
||||
|
||||
from ..models.models import ServeEntity
|
||||
from ..api.schemas import ServeRequest, ServerResponse
|
||||
from ..service.service import Service
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def setup_and_teardown():
|
||||
db.init_db("sqlite:///:memory:")
|
||||
db.create_all()
|
||||
yield
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def service(system_app: SystemApp):
|
||||
instance = Service(system_app)
|
||||
instance.init_app(system_app)
|
||||
return instance
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def default_entity_dict():
|
||||
# TODO: build your default entity dict
|
||||
return {}
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"system_app",
|
||||
[{"app_config": {"DEBUG": True, "dbgpt.serve.test_key": "hello"}}],
|
||||
indirect=True,
|
||||
)
|
||||
def test_config_exists(service: Service):
|
||||
system_app: SystemApp = service._system_app
|
||||
assert system_app.config.get("DEBUG") is True
|
||||
assert system_app.config.get("dbgpt.serve.test_key") == "hello"
|
||||
assert service.config is not None
|
||||
|
||||
|
||||
def test_service_create(service: Service, default_entity_dict):
|
||||
# TODO: implement your test case
|
||||
# eg. entity: ServerResponse = service.create(ServeRequest(**default_entity_dict))
|
||||
# ...
|
||||
pass
|
||||
|
||||
|
||||
def test_service_update(service: Service, default_entity_dict):
|
||||
# TODO: implement your test case
|
||||
pass
|
||||
|
||||
|
||||
def test_service_get(service: Service, default_entity_dict):
|
||||
# TODO: implement your test case
|
||||
pass
|
||||
|
||||
|
||||
def test_service_delete(service: Service, default_entity_dict):
|
||||
# TODO: implement your test case
|
||||
pass
|
||||
|
||||
|
||||
def test_service_get_list(service: Service):
|
||||
# TODO: implement your test case
|
||||
pass
|
||||
|
||||
|
||||
def test_service_get_list_by_page(service: Service):
|
||||
# TODO: implement your test case
|
||||
pass
|
||||
|
||||
|
||||
# Add more test cases according to your own logic
|
@@ -62,6 +62,7 @@ def replace_template_variables(content: str, app_name: str):
|
||||
|
||||
def copy_template_files(src_dir: str, dst_dir: str, app_name: str):
|
||||
for root, dirs, files in os.walk(src_dir):
|
||||
dirs[:] = [d for d in dirs if not _should_ignore(d)]
|
||||
relative_path = os.path.relpath(root, src_dir)
|
||||
if relative_path == ".":
|
||||
relative_path = ""
|
||||
@@ -70,6 +71,8 @@ def copy_template_files(src_dir: str, dst_dir: str, app_name: str):
|
||||
os.makedirs(target_dir, exist_ok=True)
|
||||
|
||||
for file in files:
|
||||
if _should_ignore(file):
|
||||
continue
|
||||
try:
|
||||
with open(os.path.join(root, file), "r") as f:
|
||||
content = f.read()
|
||||
@@ -81,3 +84,9 @@ def copy_template_files(src_dir: str, dst_dir: str, app_name: str):
|
||||
except Exception as e:
|
||||
click.echo(f"Error copying file {file} from {src_dir} to {dst_dir}")
|
||||
raise e
|
||||
|
||||
|
||||
def _should_ignore(file_or_dir: str):
|
||||
"""Return True if the given file or directory should be ignored.""" ""
|
||||
ignore_patterns = [".pyc", "__pycache__"]
|
||||
return any(pattern in file_or_dir for pattern in ignore_patterns)
|
||||
|
Reference in New Issue
Block a user