feat(core): Add API authentication for serve template (#950)

This commit is contained in:
Fangyin Cheng
2023-12-19 13:41:02 +08:00
committed by GitHub
parent 6739993b94
commit a10d5f57b2
34 changed files with 1293 additions and 377 deletions

View File

@@ -16,4 +16,6 @@ class BaseServeConfig(BaseParameters):
config_prefix (str): Configuration prefix
"""
config_dict = config.get_all_by_prefix(config_prefix)
# remove prefix
config_dict = {k[len(config_prefix) :]: v for k, v in config_dict.items()}
return cls(**config_dict)

View File

View File

@@ -0,0 +1,59 @@
import pytest
import pytest_asyncio
from typing import Dict
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from httpx import AsyncClient
from dbgpt.component import SystemApp
from dbgpt.util import AppConfig
app = FastAPI()
app.add_middleware(
CORSMiddleware,
allow_origins=["*"],
allow_credentials=True,
allow_methods=["GET", "POST", "PUT", "PATCH", "DELETE", "OPTIONS"],
allow_headers=["*"],
)
def create_system_app(param: Dict) -> SystemApp:
app_config = param.get("app_config", {})
if isinstance(app_config, dict):
app_config = AppConfig(configs=app_config)
elif not isinstance(app_config, AppConfig):
raise RuntimeError("app_config must be AppConfig or dict")
return SystemApp(app, app_config)
@pytest_asyncio.fixture
async def asystem_app(request):
param = getattr(request, "param", {})
return create_system_app(param)
@pytest.fixture
def system_app(request):
param = getattr(request, "param", {})
return create_system_app(param)
@pytest_asyncio.fixture
async def client(request, asystem_app: SystemApp):
param = getattr(request, "param", {})
headers = param.get("headers", {})
base_url = param.get("base_url", "http://test")
client_api_key = param.get("client_api_key")
routers = param.get("routers", [])
app_caller = param.get("app_caller")
if "api_keys" in param:
del param["api_keys"]
if client_api_key:
headers["Authorization"] = "Bearer " + client_api_key
async with AsyncClient(app=app, base_url=base_url, headers=headers) as client:
for router in routers:
app.include_router(router)
if app_caller:
app_caller(app, asystem_app)
yield client

View File

@@ -1,5 +1,8 @@
from typing import Optional, List
from fastapi import APIRouter, Depends, Query
from functools import cache
from fastapi import APIRouter, Depends, Query, HTTPException
from fastapi.security.http import HTTPAuthorizationCredentials, HTTPBearer
from dbgpt.component import SystemApp
from dbgpt.serve.core import Result
@@ -20,14 +23,79 @@ def get_service() -> Service:
return global_system_app.get_component(SERVE_SERVICE_COMPONENT_NAME, Service)
get_bearer_token = HTTPBearer(auto_error=False)
@cache
def _parse_api_keys(api_keys: str) -> List[str]:
"""Parse the string api keys to a list
Args:
api_keys (str): The string api keys
Returns:
List[str]: The list of api keys
"""
if not api_keys:
return []
return [key.strip() for key in api_keys.split(",")]
async def check_api_key(
auth: Optional[HTTPAuthorizationCredentials] = Depends(get_bearer_token),
service: Service = Depends(get_service),
) -> Optional[str]:
"""Check the api key
If the api key is not set, allow all.
Your can pass the token in you request header like this:
.. code-block:: python
import requests
client_api_key = "your_api_key"
headers = {"Authorization": "Bearer " + client_api_key }
res = requests.get("http://test/hello", headers=headers)
assert res.status_code == 200
"""
if service.config.api_keys:
api_keys = _parse_api_keys(service.config.api_keys)
if auth is None or (token := auth.credentials) not in api_keys:
raise HTTPException(
status_code=401,
detail={
"error": {
"message": "",
"type": "invalid_request_error",
"param": None,
"code": "invalid_api_key",
}
},
)
return token
else:
# api_keys not set; allow all
return None
@router.get("/health")
async def health():
"""Health check endpoint"""
return {"status": "ok"}
@router.get("/test_auth", dependencies=[Depends(check_api_key)])
async def test_auth():
"""Test auth endpoint"""
return {"status": "ok"}
# TODO: Compatible with old API, will be modified in the future
@router.post("/add", response_model=Result[ServerResponse])
@router.post(
"/add", response_model=Result[ServerResponse], dependencies=[Depends(check_api_key)]
)
async def create(
request: ServeRequest, service: Service = Depends(get_service)
) -> Result[ServerResponse]:
@@ -42,7 +110,11 @@ async def create(
return Result.succ(service.create(request))
@router.post("/update", response_model=Result[ServerResponse])
@router.post(
"/update",
response_model=Result[ServerResponse],
dependencies=[Depends(check_api_key)],
)
async def update(
request: ServeRequest, service: Service = Depends(get_service)
) -> Result[ServerResponse]:
@@ -57,7 +129,9 @@ async def update(
return Result.succ(service.update(request))
@router.post("/delete", response_model=Result[None])
@router.post(
"/delete", response_model=Result[None], dependencies=[Depends(check_api_key)]
)
async def delete(
request: ServeRequest, service: Service = Depends(get_service)
) -> Result[None]:
@@ -72,7 +146,11 @@ async def delete(
return Result.succ(service.delete(request))
@router.post("/list", response_model=Result[List[ServerResponse]])
@router.post(
"/list",
response_model=Result[List[ServerResponse]],
dependencies=[Depends(check_api_key)],
)
async def query(
request: ServeRequest, service: Service = Depends(get_service)
) -> Result[List[ServerResponse]]:
@@ -87,7 +165,11 @@ async def query(
return Result.succ(service.get_list(request))
@router.post("/query_page", response_model=Result[PaginationResult[ServerResponse]])
@router.post(
"/query_page",
response_model=Result[PaginationResult[ServerResponse]],
dependencies=[Depends(check_api_key)],
)
async def query_page(
request: ServeRequest,
page: Optional[int] = Query(default=1, description="current page"),

View File

@@ -1,73 +1,78 @@
# Define your Pydantic schemas here
from typing import Optional
from dbgpt._private.pydantic import BaseModel, Field
from ..config import SERVE_APP_NAME_HUMP
class ServeRequest(BaseModel):
"""Prompt request model"""
chat_scene: Optional[str] = None
"""
The chat scene, e.g. chat_with_db_execute, chat_excel, chat_with_db_qa.
"""
class Config:
title = f"ServeRequest for {SERVE_APP_NAME_HUMP}"
sub_chat_scene: Optional[str] = None
"""
The sub chat scene.
"""
chat_scene: Optional[str] = Field(
None,
description="The chat scene, e.g. chat_with_db_execute, chat_excel, chat_with_db_qa.",
examples=["chat_with_db_execute", "chat_excel", "chat_with_db_qa"],
)
prompt_type: Optional[str] = None
"""
The prompt type, either common or private.
"""
sub_chat_scene: Optional[str] = Field(
None,
description="The sub chat scene.",
examples=["sub_scene_1", "sub_scene_2", "sub_scene_3"],
)
content: Optional[str] = None
"""
The prompt content.
"""
prompt_type: Optional[str] = Field(
None,
description="The prompt type, either common or private.",
examples=["common", "private"],
)
prompt_name: Optional[str] = Field(
None,
description="The prompt name.",
examples=["code_assistant", "joker", "data_analysis_expert"],
)
content: Optional[str] = Field(
None,
description="The prompt content.",
examples=[
"Write a qsort function in python",
"Tell me a joke about AI",
"You are a data analysis expert.",
],
)
user_name: Optional[str] = None
"""
The user name.
"""
user_name: Optional[str] = Field(
None,
description="The user name.",
examples=["zhangsan", "lisi", "wangwu"],
)
sys_code: Optional[str] = None
"""
System code
"""
prompt_name: Optional[str] = None
"""
The prompt name.
"""
sys_code: Optional[str] = Field(
None,
description="The system code.",
examples=["dbgpt", "auth_manager", "data_platform"],
)
class ServerResponse(BaseModel):
class ServerResponse(ServeRequest):
"""Prompt response model"""
id: int = None
"""chat_scene: for example: chat_with_db_execute, chat_excel, chat_with_db_qa"""
class Config:
title = f"ServerResponse for {SERVE_APP_NAME_HUMP}"
chat_scene: str = None
"""sub_chat_scene: sub chat scene"""
sub_chat_scene: str = None
"""prompt_type: common or private"""
prompt_type: str = None
"""content: prompt content"""
content: str = None
"""user_name: user name"""
user_name: str = None
sys_code: Optional[str] = None
"""
System code
"""
"""prompt_name: prompt name"""
prompt_name: str = None
gmt_created: str = None
gmt_modified: str = None
id: Optional[int] = Field(
None,
description="The prompt id.",
examples=[1, 2, 3],
)
gmt_created: Optional[str] = Field(
None,
description="The prompt created time.",
examples=["2021-08-01 12:00:00", "2021-08-01 12:00:01", "2021-08-01 12:00:02"],
)
gmt_modified: Optional[str] = Field(
None,
description="The prompt modified time.",
examples=["2021-08-01 12:00:00", "2021-08-01 12:00:01", "2021-08-01 12:00:02"],
)

View File

@@ -1,4 +1,5 @@
from dataclasses import dataclass
from typing import Optional
from dataclasses import dataclass, field
from dbgpt.serve.core import BaseServeConfig
@@ -17,3 +18,15 @@ class ServeConfig(BaseServeConfig):
"""Parameters for the serve command"""
# TODO: add your own parameters here
api_keys: Optional[str] = field(
default=None, metadata={"help": "API keys for the endpoint, if None, allow all"}
)
default_user: Optional[str] = field(
default=None,
metadata={"help": "Default user name for prompt"},
)
default_sys_code: Optional[str] = field(
default=None,
metadata={"help": "Default system code for prompt"},
)

View File

@@ -2,7 +2,13 @@ from typing import List, Optional
from dbgpt.component import BaseComponent, SystemApp
from .api.endpoints import router, init_endpoints
from .config import SERVE_APP_NAME, SERVE_APP_NAME_HUMP, APP_NAME
from .config import (
SERVE_APP_NAME,
SERVE_APP_NAME_HUMP,
APP_NAME,
SERVE_CONFIG_KEY_PREFIX,
ServeConfig,
)
class Serve(BaseComponent):

View File

@@ -13,10 +13,10 @@ class Service(BaseService[ServeEntity, ServeRequest, ServerResponse]):
name = SERVE_SERVICE_COMPONENT_NAME
def __init__(self, system_app: SystemApp):
def __init__(self, system_app: SystemApp, dao: Optional[ServeDao] = None):
self._system_app = None
self._serve_config: ServeConfig = None
self._dao: ServeDao = None
self._dao: ServeDao = dao
super().__init__(system_app)
def init_app(self, system_app: SystemApp) -> None:
@@ -28,7 +28,7 @@ class Service(BaseService[ServeEntity, ServeRequest, ServerResponse]):
self._serve_config = ServeConfig.from_app_config(
system_app.config, SERVE_CONFIG_KEY_PREFIX
)
self._dao = ServeDao(self._serve_config)
self._dao = self._dao or ServeDao(self._serve_config)
self._system_app = system_app
@property
@@ -41,6 +41,22 @@ class Service(BaseService[ServeEntity, ServeRequest, ServerResponse]):
"""Returns the internal ServeConfig."""
return self._serve_config
def create(self, request: ServeRequest) -> ServerResponse:
"""Create a new Prompt entity
Args:
request (ServeRequest): The request
Returns:
ServerResponse: The response
"""
if not request.user_name:
request.user_name = self.config.default_user
if not request.sys_code:
request.sys_code = self.config.default_sys_code
return super().create(request)
def update(self, request: ServeRequest) -> ServerResponse:
"""Update a Prompt entity

View File

View File

@@ -0,0 +1,176 @@
import pytest
from httpx import AsyncClient
from fastapi import FastAPI
from dbgpt.component import SystemApp
from dbgpt.storage.metadata import db
from dbgpt.util import PaginationResult
from ..config import SERVE_CONFIG_KEY_PREFIX
from ..api.endpoints import router, init_endpoints
from ..api.schemas import ServeRequest, ServerResponse
from dbgpt.serve.core.tests.conftest import client, asystem_app
@pytest.fixture(autouse=True)
def setup_and_teardown():
db.init_db("sqlite:///:memory:")
db.create_all()
yield
def client_init_caller(app: FastAPI, system_app: SystemApp):
app.include_router(router)
init_endpoints(system_app)
async def _create_and_validate(
client: AsyncClient, sys_code: str, content: str, expect_id: int = 1, **kwargs
):
req_json = {"sys_code": sys_code, "content": content}
req_json.update(kwargs)
response = await client.post("/add", json=req_json)
assert response.status_code == 200
json_res = response.json()
assert "success" in json_res and json_res["success"]
assert "data" in json_res and json_res["data"]
data = json_res["data"]
res_obj = ServerResponse(**data)
assert res_obj.id == expect_id
assert res_obj.sys_code == sys_code
assert res_obj.content == content
@pytest.mark.asyncio
@pytest.mark.parametrize(
"client, asystem_app, has_auth",
[
(
{
"app_caller": client_init_caller,
"client_api_key": "test_token1",
},
{
"app_config": {
f"{SERVE_CONFIG_KEY_PREFIX}api_keys": "test_token1,test_token2"
}
},
True,
),
(
{
"app_caller": client_init_caller,
"client_api_key": "error_token",
},
{
"app_config": {
f"{SERVE_CONFIG_KEY_PREFIX}api_keys": "test_token1,test_token2"
}
},
False,
),
],
indirect=["client", "asystem_app"],
)
async def test_api_health(client: AsyncClient, asystem_app, has_auth: bool):
response = await client.get("/test_auth")
if has_auth:
assert response.status_code == 200
assert response.json() == {"status": "ok"}
else:
assert response.status_code == 401
assert response.json() == {
"detail": {
"error": {
"message": "",
"type": "invalid_request_error",
"param": None,
"code": "invalid_api_key",
}
}
}
@pytest.mark.asyncio
@pytest.mark.parametrize(
"client", [{"app_caller": client_init_caller}], indirect=["client"]
)
async def test_api_auth(client: AsyncClient):
response = await client.get("/health")
assert response.status_code == 200
assert response.json() == {"status": "ok"}
@pytest.mark.asyncio
@pytest.mark.parametrize(
"client", [{"app_caller": client_init_caller}], indirect=["client"]
)
async def test_api_create(client: AsyncClient):
await _create_and_validate(client, "test", "test")
@pytest.mark.asyncio
@pytest.mark.parametrize(
"client", [{"app_caller": client_init_caller}], indirect=["client"]
)
async def test_api_update(client: AsyncClient):
await _create_and_validate(client, "test", "test")
response = await client.post("/update", json={"id": 1, "content": "test2"})
assert response.status_code == 200
json_res = response.json()
assert "success" in json_res and json_res["success"]
assert "data" in json_res and json_res["data"]
data = json_res["data"]
res_obj = ServerResponse(**data)
assert res_obj.id == 1
assert res_obj.sys_code == "test"
assert res_obj.content == "test2"
@pytest.mark.asyncio
@pytest.mark.parametrize(
"client", [{"app_caller": client_init_caller}], indirect=["client"]
)
async def test_api_query(client: AsyncClient):
for i in range(10):
await _create_and_validate(
client, "test", f"test{i}", expect_id=i + 1, prompt_name=f"prompt_name_{i}"
)
response = await client.post("/list", json={"sys_code": "test"})
assert response.status_code == 200
json_res = response.json()
assert "success" in json_res and json_res["success"]
assert "data" in json_res and json_res["data"]
data = json_res["data"]
assert len(data) == 10
res_obj = ServerResponse(**data[0])
assert res_obj.id == 1
assert res_obj.sys_code == "test"
assert res_obj.content == "test0"
@pytest.mark.asyncio
@pytest.mark.parametrize(
"client", [{"app_caller": client_init_caller}], indirect=["client"]
)
async def test_api_query_by_page(client: AsyncClient):
for i in range(10):
await _create_and_validate(
client, "test", f"test{i}", expect_id=i + 1, prompt_name=f"prompt_name_{i}"
)
response = await client.post(
"/query_page", params={"page": 1, "page_size": 5}, json={"sys_code": "test"}
)
assert response.status_code == 200
json_res = response.json()
assert "success" in json_res and json_res["success"]
assert "data" in json_res and json_res["data"]
data = json_res["data"]
page_result: PaginationResult = PaginationResult(**data)
assert page_result.total_count == 10
assert page_result.total_pages == 2
assert page_result.page == 1
assert page_result.page_size == 5
assert len(page_result.items) == 5

View File

@@ -0,0 +1,257 @@
from typing import List
import pytest
from dbgpt.storage.metadata import db
from ..config import ServeConfig
from ..api.schemas import ServeRequest, ServerResponse
from ..models.models import ServeEntity, ServeDao
@pytest.fixture(autouse=True)
def setup_and_teardown():
db.init_db("sqlite:///:memory:")
db.create_all()
yield
@pytest.fixture
def server_config():
return ServeConfig()
@pytest.fixture
def dao(server_config):
return ServeDao(server_config)
@pytest.fixture
def default_entity_dict():
return {
"chat_scene": "chat_data",
"sub_chat_scene": "excel",
"prompt_type": "common",
"prompt_name": "my_prompt_1",
"content": "Write a qsort function in python.",
"user_name": "zhangsan",
"sys_code": "dbgpt",
}
def test_table_exist():
assert ServeEntity.__tablename__ in db.metadata.tables
def test_entity_create(default_entity_dict):
entity: ServeEntity = ServeEntity.create(**default_entity_dict)
with db.session() as session:
db_entity: ServeEntity = session.query(ServeEntity).get(entity.id)
assert db_entity.id == entity.id
assert db_entity.chat_scene == "chat_data"
assert db_entity.sub_chat_scene == "excel"
assert db_entity.prompt_type == "common"
assert db_entity.prompt_name == "my_prompt_1"
assert db_entity.content == "Write a qsort function in python."
assert db_entity.user_name == "zhangsan"
assert db_entity.sys_code == "dbgpt"
assert db_entity.gmt_created is not None
assert db_entity.gmt_modified is not None
def test_entity_unique_key(default_entity_dict):
ServeEntity.create(**default_entity_dict)
with pytest.raises(Exception):
ServeEntity.create(**{"prompt_name": "my_prompt_1", "sys_code": "dbgpt"})
def test_entity_get(default_entity_dict):
entity: ServeEntity = ServeEntity.create(**default_entity_dict)
db_entity: ServeEntity = ServeEntity.get(entity.id)
assert db_entity.id == entity.id
assert db_entity.chat_scene == "chat_data"
assert db_entity.sub_chat_scene == "excel"
assert db_entity.prompt_type == "common"
assert db_entity.prompt_name == "my_prompt_1"
assert db_entity.content == "Write a qsort function in python."
assert db_entity.user_name == "zhangsan"
assert db_entity.sys_code == "dbgpt"
assert db_entity.gmt_created is not None
assert db_entity.gmt_modified is not None
def test_entity_update(default_entity_dict):
entity: ServeEntity = ServeEntity.create(**default_entity_dict)
entity.update(prompt_name="my_prompt_2")
db_entity: ServeEntity = ServeEntity.get(entity.id)
assert db_entity.id == entity.id
assert db_entity.chat_scene == "chat_data"
assert db_entity.sub_chat_scene == "excel"
assert db_entity.prompt_type == "common"
assert db_entity.prompt_name == "my_prompt_2"
assert db_entity.content == "Write a qsort function in python."
assert db_entity.user_name == "zhangsan"
assert db_entity.sys_code == "dbgpt"
assert db_entity.gmt_created is not None
assert db_entity.gmt_modified is not None
def test_entity_delete(default_entity_dict):
entity: ServeEntity = ServeEntity.create(**default_entity_dict)
entity.delete()
db_entity: ServeEntity = ServeEntity.get(entity.id)
assert db_entity is None
def test_entity_all():
for i in range(10):
ServeEntity.create(
chat_scene="chat_data",
sub_chat_scene="excel",
prompt_type="common",
prompt_name=f"my_prompt_{i}",
content="Write a qsort function in python.",
user_name="zhangsan",
sys_code="dbgpt",
)
entities = ServeEntity.all()
assert len(entities) == 10
for entity in entities:
assert entity.chat_scene == "chat_data"
assert entity.sub_chat_scene == "excel"
assert entity.prompt_type == "common"
assert entity.content == "Write a qsort function in python."
assert entity.user_name == "zhangsan"
assert entity.sys_code == "dbgpt"
assert entity.gmt_created is not None
assert entity.gmt_modified is not None
def test_dao_create(dao, default_entity_dict):
req = ServeRequest(**default_entity_dict)
res: ServerResponse = dao.create(req)
assert res is not None
assert res.id == 1
assert res.chat_scene == "chat_data"
assert res.sub_chat_scene == "excel"
assert res.prompt_type == "common"
assert res.prompt_name == "my_prompt_1"
assert res.content == "Write a qsort function in python."
assert res.user_name == "zhangsan"
assert res.sys_code == "dbgpt"
def test_dao_get_one(dao, default_entity_dict):
req = ServeRequest(**default_entity_dict)
res: ServerResponse = dao.create(req)
res: ServerResponse = dao.get_one(
{"prompt_name": "my_prompt_1", "sys_code": "dbgpt"}
)
assert res is not None
assert res.id == 1
assert res.chat_scene == "chat_data"
assert res.sub_chat_scene == "excel"
assert res.prompt_type == "common"
assert res.prompt_name == "my_prompt_1"
assert res.content == "Write a qsort function in python."
assert res.user_name == "zhangsan"
assert res.sys_code == "dbgpt"
def test_get_dao_get_list(dao):
for i in range(10):
dao.create(
ServeRequest(
chat_scene="chat_data",
sub_chat_scene="excel",
prompt_type="common",
prompt_name=f"my_prompt_{i}",
content="Write a qsort function in python.",
user_name="zhangsan" if i % 2 == 0 else "lisi",
sys_code="dbgpt",
)
)
res: List[ServerResponse] = dao.get_list({"sys_code": "dbgpt"})
assert len(res) == 10
for i, r in enumerate(res):
assert r.id == i + 1
assert r.chat_scene == "chat_data"
assert r.sub_chat_scene == "excel"
assert r.prompt_type == "common"
assert r.prompt_name == f"my_prompt_{i}"
assert r.content == "Write a qsort function in python."
assert r.user_name == "zhangsan" if i % 2 == 0 else "lisi"
assert r.sys_code == "dbgpt"
half_res: List[ServerResponse] = dao.get_list({"user_name": "zhangsan"})
assert len(half_res) == 5
def test_dao_update(dao, default_entity_dict):
req = ServeRequest(**default_entity_dict)
res: ServerResponse = dao.create(req)
res: ServerResponse = dao.update(
{"prompt_name": "my_prompt_1", "sys_code": "dbgpt"},
ServeRequest(prompt_name="my_prompt_2"),
)
assert res is not None
assert res.id == 1
assert res.chat_scene == "chat_data"
assert res.sub_chat_scene == "excel"
assert res.prompt_type == "common"
assert res.prompt_name == "my_prompt_2"
assert res.content == "Write a qsort function in python."
assert res.user_name == "zhangsan"
assert res.sys_code == "dbgpt"
def test_dao_delete(dao, default_entity_dict):
req = ServeRequest(**default_entity_dict)
res: ServerResponse = dao.create(req)
dao.delete({"prompt_name": "my_prompt_1", "sys_code": "dbgpt"})
res: ServerResponse = dao.get_one(
{"prompt_name": "my_prompt_1", "sys_code": "dbgpt"}
)
assert res is None
def test_dao_get_list_page(dao):
for i in range(20):
dao.create(
ServeRequest(
chat_scene="chat_data",
sub_chat_scene="excel",
prompt_type="common",
prompt_name=f"my_prompt_{i}",
content="Write a qsort function in python.",
user_name="zhangsan" if i % 2 == 0 else "lisi",
sys_code="dbgpt",
)
)
res = dao.get_list_page({"sys_code": "dbgpt"}, page=1, page_size=8)
assert res.total_count == 20
assert res.total_pages == 3
assert res.page == 1
assert res.page_size == 8
assert len(res.items) == 8
for i, r in enumerate(res.items):
assert r.id == i + 1
assert r.chat_scene == "chat_data"
assert r.sub_chat_scene == "excel"
assert r.prompt_type == "common"
assert r.prompt_name == f"my_prompt_{i}"
assert r.content == "Write a qsort function in python."
assert r.user_name == "zhangsan" if i % 2 == 0 else "lisi"
assert r.sys_code == "dbgpt"
res_half = dao.get_list_page({"user_name": "zhangsan"}, page=2, page_size=8)
assert res_half.total_count == 10
assert res_half.total_pages == 2
assert res_half.page == 2
assert res_half.page_size == 8
assert len(res_half.items) == 2
for i, r in enumerate(res_half.items):
assert r.chat_scene == "chat_data"
assert r.sub_chat_scene == "excel"
assert r.prompt_type == "common"
assert r.content == "Write a qsort function in python."
assert r.user_name == "zhangsan"
assert r.sys_code == "dbgpt"

View File

@@ -0,0 +1,154 @@
from typing import List
import pytest
from dbgpt.component import SystemApp
from dbgpt.storage.metadata import db
from dbgpt.serve.core.tests.conftest import system_app
from ..models.models import ServeEntity
from ..api.schemas import ServeRequest, ServerResponse
from ..service.service import Service
@pytest.fixture(autouse=True)
def setup_and_teardown():
db.init_db("sqlite:///:memory:")
db.create_all()
yield
@pytest.fixture
def service(system_app: SystemApp):
instance = Service(system_app)
instance.init_app(system_app)
return instance
@pytest.fixture
def default_entity_dict():
return {
"chat_scene": "chat_data",
"sub_chat_scene": "excel",
"prompt_type": "common",
"prompt_name": "my_prompt_1",
"content": "Write a qsort function in python.",
"user_name": "zhangsan",
"sys_code": "dbgpt",
}
@pytest.mark.parametrize(
"system_app",
[{"app_config": {"DEBUG": True, "dbgpt.serve.test_key": "hello"}}],
indirect=True,
)
def test_config_exists(service: Service):
system_app: SystemApp = service._system_app
assert system_app.config.get("DEBUG") is True
assert system_app.config.get("dbgpt.serve.test_key") == "hello"
assert service.config is not None
@pytest.mark.parametrize(
"system_app",
[
{
"app_config": {
"DEBUG": True,
"dbgpt.serve.prompt.default_user": "dbgpt",
"dbgpt.serve.prompt.default_sys_code": "dbgpt",
}
}
],
indirect=True,
)
def test_config_default_user(service: Service):
system_app: SystemApp = service._system_app
assert system_app.config.get("DEBUG") is True
assert system_app.config.get("dbgpt.serve.prompt.default_user") == "dbgpt"
assert service.config is not None
assert service.config.default_user == "dbgpt"
assert service.config.default_sys_code == "dbgpt"
def test_service_create(service: Service, default_entity_dict):
entity: ServerResponse = service.create(ServeRequest(**default_entity_dict))
with db.session() as session:
db_entity: ServeEntity = session.query(ServeEntity).get(entity.id)
assert db_entity.id == entity.id
assert db_entity.chat_scene == "chat_data"
assert db_entity.sub_chat_scene == "excel"
assert db_entity.prompt_type == "common"
assert db_entity.prompt_name == "my_prompt_1"
assert db_entity.content == "Write a qsort function in python."
assert db_entity.user_name == "zhangsan"
assert db_entity.sys_code == "dbgpt"
assert db_entity.gmt_created is not None
assert db_entity.gmt_modified is not None
def test_service_update(service: Service, default_entity_dict):
service.create(ServeRequest(**default_entity_dict))
entity: ServerResponse = service.update(ServeRequest(**default_entity_dict))
with db.session() as session:
db_entity: ServeEntity = session.query(ServeEntity).get(entity.id)
assert db_entity.id == entity.id
assert db_entity.chat_scene == "chat_data"
assert db_entity.sub_chat_scene == "excel"
assert db_entity.prompt_type == "common"
assert db_entity.prompt_name == "my_prompt_1"
assert db_entity.content == "Write a qsort function in python."
assert db_entity.user_name == "zhangsan"
assert db_entity.sys_code == "dbgpt"
assert db_entity.gmt_created is not None
assert db_entity.gmt_modified is not None
def test_service_get(service: Service, default_entity_dict):
service.create(ServeRequest(**default_entity_dict))
entity: ServerResponse = service.get(ServeRequest(**default_entity_dict))
with db.session() as session:
db_entity: ServeEntity = session.query(ServeEntity).get(entity.id)
assert db_entity.id == entity.id
assert db_entity.chat_scene == "chat_data"
assert db_entity.sub_chat_scene == "excel"
assert db_entity.prompt_type == "common"
assert db_entity.prompt_name == "my_prompt_1"
assert db_entity.content == "Write a qsort function in python."
assert db_entity.user_name == "zhangsan"
assert db_entity.sys_code == "dbgpt"
assert db_entity.gmt_created is not None
assert db_entity.gmt_modified is not None
def test_service_delete(service: Service, default_entity_dict):
service.create(ServeRequest(**default_entity_dict))
service.delete(ServeRequest(**default_entity_dict))
entity: ServerResponse = service.get(ServeRequest(**default_entity_dict))
assert entity is None
def test_service_get_list(service: Service):
for i in range(3):
service.create(
ServeRequest(**{"prompt_name": f"prompt_{i}", "sys_code": "dbgpt"})
)
entities: List[ServerResponse] = service.get_list(ServeRequest(sys_code="dbgpt"))
assert len(entities) == 3
for i, entity in enumerate(entities):
assert entity.sys_code == "dbgpt"
assert entity.prompt_name == f"prompt_{i}"
def test_service_get_list_by_page(service: Service):
for i in range(3):
service.create(
ServeRequest(**{"prompt_name": f"prompt_{i}", "sys_code": "dbgpt"})
)
res = service.get_list_by_page(ServeRequest(sys_code="dbgpt"), page=1, page_size=2)
assert res is not None
assert res.total_count == 3
assert res.total_pages == 2
assert len(res.items) == 2
for i, entity in enumerate(res.items):
assert entity.sys_code == "dbgpt"
assert entity.prompt_name == f"prompt_{i}"

View File

@@ -1,5 +1,8 @@
from typing import Optional, List
from fastapi import APIRouter, Depends, Query
from functools import cache
from fastapi import APIRouter, Depends, Query, HTTPException
from fastapi.security.http import HTTPAuthorizationCredentials, HTTPBearer
from dbgpt.component import SystemApp
from dbgpt.serve.core import Result
@@ -20,13 +23,78 @@ def get_service() -> Service:
return global_system_app.get_component(SERVE_SERVICE_COMPONENT_NAME, Service)
get_bearer_token = HTTPBearer(auto_error=False)
@cache
def _parse_api_keys(api_keys: str) -> List[str]:
"""Parse the string api keys to a list
Args:
api_keys (str): The string api keys
Returns:
List[str]: The list of api keys
"""
if not api_keys:
return []
return [key.strip() for key in api_keys.split(",")]
async def check_api_key(
auth: Optional[HTTPAuthorizationCredentials] = Depends(get_bearer_token),
service: Service = Depends(get_service),
) -> Optional[str]:
"""Check the api key
If the api key is not set, allow all.
Your can pass the token in you request header like this:
.. code-block:: python
import requests
client_api_key = "your_api_key"
headers = {"Authorization": "Bearer " + client_api_key }
res = requests.get("http://test/hello", headers=headers)
assert res.status_code == 200
"""
if service.config.api_keys:
api_keys = _parse_api_keys(service.config.api_keys)
if auth is None or (token := auth.credentials) not in api_keys:
raise HTTPException(
status_code=401,
detail={
"error": {
"message": "",
"type": "invalid_request_error",
"param": None,
"code": "invalid_api_key",
}
},
)
return token
else:
# api_keys not set; allow all
return None
@router.get("/health")
async def health():
"""Health check endpoint"""
return {"status": "ok"}
@router.post("/", response_model=Result[ServerResponse])
@router.get("/test_auth", dependencies=[Depends(check_api_key)])
async def test_auth():
"""Test auth endpoint"""
return {"status": "ok"}
@router.post(
"/", response_model=Result[ServerResponse], dependencies=[Depends(check_api_key)]
)
async def create(
request: ServeRequest, service: Service = Depends(get_service)
) -> Result[ServerResponse]:
@@ -41,7 +109,9 @@ async def create(
return Result.succ(service.create(request))
@router.put("/", response_model=Result[ServerResponse])
@router.put(
"/", response_model=Result[ServerResponse], dependencies=[Depends(check_api_key)]
)
async def update(
request: ServeRequest, service: Service = Depends(get_service)
) -> Result[ServerResponse]:
@@ -56,7 +126,11 @@ async def update(
return Result.succ(service.update(request))
@router.post("/query", response_model=Result[ServerResponse])
@router.post(
"/query",
response_model=Result[ServerResponse],
dependencies=[Depends(check_api_key)],
)
async def query(
request: ServeRequest, service: Service = Depends(get_service)
) -> Result[ServerResponse]:
@@ -71,7 +145,11 @@ async def query(
return Result.succ(service.get(request))
@router.post("/query_page", response_model=Result[PaginationResult[ServerResponse]])
@router.post(
"/query_page",
response_model=Result[PaginationResult[ServerResponse]],
dependencies=[Depends(check_api_key)],
)
async def query_page(
request: ServeRequest,
page: Optional[int] = Query(default=1, description="current page"),

View File

@@ -1,5 +1,6 @@
# Define your Pydantic schemas here
from dbgpt._private.pydantic import BaseModel, Field
from ..config import SERVE_APP_NAME_HUMP
class ServeRequest(BaseModel):
@@ -7,8 +8,13 @@ class ServeRequest(BaseModel):
# TODO define your own fields here
class Config:
title = f"ServeRequest for {SERVE_APP_NAME_HUMP}"
class ServerResponse(BaseModel):
"""{__template_app_name__hump__} response model"""
# TODO define your own fields here
class Config:
title = f"ServerResponse for {SERVE_APP_NAME_HUMP}"

View File

@@ -1,4 +1,5 @@
from dataclasses import dataclass
from typing import Optional
from dataclasses import dataclass, field
from dbgpt.serve.core import BaseServeConfig
@@ -17,3 +18,6 @@ class ServeConfig(BaseServeConfig):
"""Parameters for the serve command"""
# TODO: add your own parameters here
api_keys: Optional[str] = field(
default=None, metadata={"help": "API keys for the endpoint, if None, allow all"}
)

View File

@@ -2,7 +2,13 @@ from typing import List, Optional
from dbgpt.component import BaseComponent, SystemApp
from .api.endpoints import router, init_endpoints
from .config import SERVE_APP_NAME, SERVE_APP_NAME_HUMP, APP_NAME
from .config import (
SERVE_APP_NAME,
SERVE_APP_NAME_HUMP,
APP_NAME,
SERVE_CONFIG_KEY_PREFIX,
ServeConfig,
)
class Serve(BaseComponent):

View File

@@ -13,10 +13,10 @@ class Service(BaseService[ServeEntity, ServeRequest, ServerResponse]):
name = SERVE_SERVICE_COMPONENT_NAME
def __init__(self, system_app: SystemApp):
def __init__(self, system_app: SystemApp, dao: Optional[ServeDao] = None):
self._system_app = None
self._serve_config: ServeConfig = None
self._dao: ServeDao = None
self._dao: ServeDao = dao
super().__init__(system_app)
def init_app(self, system_app: SystemApp) -> None:
@@ -28,7 +28,7 @@ class Service(BaseService[ServeEntity, ServeRequest, ServerResponse]):
self._serve_config = ServeConfig.from_app_config(
system_app.config, SERVE_CONFIG_KEY_PREFIX
)
self._dao = ServeDao(self._serve_config)
self._dao = self._dao or ServeDao(self._serve_config)
self._system_app = system_app
@property

View File

@@ -0,0 +1,124 @@
import pytest
from httpx import AsyncClient
from fastapi import FastAPI
from dbgpt.component import SystemApp
from dbgpt.storage.metadata import db
from dbgpt.util import PaginationResult
from ..config import SERVE_CONFIG_KEY_PREFIX
from ..api.endpoints import router, init_endpoints
from ..api.schemas import ServeRequest, ServerResponse
from dbgpt.serve.core.tests.conftest import client, asystem_app
@pytest.fixture(autouse=True)
def setup_and_teardown():
db.init_db("sqlite:///:memory:")
db.create_all()
yield
def client_init_caller(app: FastAPI, system_app: SystemApp):
app.include_router(router)
init_endpoints(system_app)
@pytest.mark.asyncio
@pytest.mark.parametrize(
"client, asystem_app, has_auth",
[
(
{
"app_caller": client_init_caller,
"client_api_key": "test_token1",
},
{
"app_config": {
f"{SERVE_CONFIG_KEY_PREFIX}api_keys": "test_token1,test_token2"
}
},
True,
),
(
{
"app_caller": client_init_caller,
"client_api_key": "error_token",
},
{
"app_config": {
f"{SERVE_CONFIG_KEY_PREFIX}api_keys": "test_token1,test_token2"
}
},
False,
),
],
indirect=["client", "asystem_app"],
)
async def test_api_health(client: AsyncClient, asystem_app, has_auth: bool):
response = await client.get("/test_auth")
if has_auth:
assert response.status_code == 200
assert response.json() == {"status": "ok"}
else:
assert response.status_code == 401
assert response.json() == {
"detail": {
"error": {
"message": "",
"type": "invalid_request_error",
"param": None,
"code": "invalid_api_key",
}
}
}
@pytest.mark.asyncio
@pytest.mark.parametrize(
"client", [{"app_caller": client_init_caller}], indirect=["client"]
)
async def test_api_health(client: AsyncClient):
response = await client.get("/health")
assert response.status_code == 200
assert response.json() == {"status": "ok"}
@pytest.mark.asyncio
@pytest.mark.parametrize(
"client", [{"app_caller": client_init_caller}], indirect=["client"]
)
async def test_api_create(client: AsyncClient):
# TODO: add your test case
pass
@pytest.mark.asyncio
@pytest.mark.parametrize(
"client", [{"app_caller": client_init_caller}], indirect=["client"]
)
async def test_api_update(client: AsyncClient):
# TODO: implement your test case
pass
@pytest.mark.asyncio
@pytest.mark.parametrize(
"client", [{"app_caller": client_init_caller}], indirect=["client"]
)
async def test_api_query(client: AsyncClient):
# TODO: implement your test case
pass
@pytest.mark.asyncio
@pytest.mark.parametrize(
"client", [{"app_caller": client_init_caller}], indirect=["client"]
)
async def test_api_query_by_page(client: AsyncClient):
# TODO: implement your test case
pass
# Add more test cases according to your own logic

View File

@@ -0,0 +1,109 @@
from typing import List
import pytest
from dbgpt.storage.metadata import db
from ..config import ServeConfig
from ..api.schemas import ServeRequest, ServerResponse
from ..models.models import ServeEntity, ServeDao
@pytest.fixture(autouse=True)
def setup_and_teardown():
db.init_db("sqlite:///:memory:")
db.create_all()
yield
@pytest.fixture
def server_config():
# TODO : build your server config
return ServeConfig()
@pytest.fixture
def dao(server_config):
return ServeDao(server_config)
@pytest.fixture
def default_entity_dict():
# TODO: build your default entity dict
return {}
def test_table_exist():
assert ServeEntity.__tablename__ in db.metadata.tables
def test_entity_create(default_entity_dict):
entity: ServeEntity = ServeEntity.create(**default_entity_dict)
# TODO: implement your test case
with db.session() as session:
db_entity: ServeEntity = session.query(ServeEntity).get(entity.id)
assert db_entity.id == entity.id
def test_entity_unique_key(default_entity_dict):
# TODO: implement your test case
pass
def test_entity_get(default_entity_dict):
entity: ServeEntity = ServeEntity.create(**default_entity_dict)
db_entity: ServeEntity = ServeEntity.get(entity.id)
assert db_entity.id == entity.id
# TODO: implement your test case
def test_entity_update(default_entity_dict):
# TODO: implement your test case
pass
def test_entity_delete(default_entity_dict):
# TODO: implement your test case
entity: ServeEntity = ServeEntity.create(**default_entity_dict)
entity.delete()
db_entity: ServeEntity = ServeEntity.get(entity.id)
assert db_entity is None
def test_entity_all():
# TODO: implement your test case
pass
def test_dao_create(dao, default_entity_dict):
# TODO: implement your test case
req = ServeRequest(**default_entity_dict)
res: ServerResponse = dao.create(req)
assert res is not None
def test_dao_get_one(dao, default_entity_dict):
# TODO: implement your test case
req = ServeRequest(**default_entity_dict)
res: ServerResponse = dao.create(req)
def test_get_dao_get_list(dao):
# TODO: implement your test case
pass
def test_dao_update(dao, default_entity_dict):
# TODO: implement your test case
pass
def test_dao_delete(dao, default_entity_dict):
# TODO: implement your test case
pass
def test_dao_get_list_page(dao):
# TODO: implement your test case
pass
# Add more test cases according to your own logic

View File

@@ -0,0 +1,76 @@
from typing import List
import pytest
from dbgpt.component import SystemApp
from dbgpt.storage.metadata import db
from dbgpt.serve.core.tests.conftest import system_app
from ..models.models import ServeEntity
from ..api.schemas import ServeRequest, ServerResponse
from ..service.service import Service
@pytest.fixture(autouse=True)
def setup_and_teardown():
db.init_db("sqlite:///:memory:")
db.create_all()
yield
@pytest.fixture
def service(system_app: SystemApp):
instance = Service(system_app)
instance.init_app(system_app)
return instance
@pytest.fixture
def default_entity_dict():
# TODO: build your default entity dict
return {}
@pytest.mark.parametrize(
"system_app",
[{"app_config": {"DEBUG": True, "dbgpt.serve.test_key": "hello"}}],
indirect=True,
)
def test_config_exists(service: Service):
system_app: SystemApp = service._system_app
assert system_app.config.get("DEBUG") is True
assert system_app.config.get("dbgpt.serve.test_key") == "hello"
assert service.config is not None
def test_service_create(service: Service, default_entity_dict):
# TODO: implement your test case
# eg. entity: ServerResponse = service.create(ServeRequest(**default_entity_dict))
# ...
pass
def test_service_update(service: Service, default_entity_dict):
# TODO: implement your test case
pass
def test_service_get(service: Service, default_entity_dict):
# TODO: implement your test case
pass
def test_service_delete(service: Service, default_entity_dict):
# TODO: implement your test case
pass
def test_service_get_list(service: Service):
# TODO: implement your test case
pass
def test_service_get_list_by_page(service: Service):
# TODO: implement your test case
pass
# Add more test cases according to your own logic

View File

@@ -62,6 +62,7 @@ def replace_template_variables(content: str, app_name: str):
def copy_template_files(src_dir: str, dst_dir: str, app_name: str):
for root, dirs, files in os.walk(src_dir):
dirs[:] = [d for d in dirs if not _should_ignore(d)]
relative_path = os.path.relpath(root, src_dir)
if relative_path == ".":
relative_path = ""
@@ -70,6 +71,8 @@ def copy_template_files(src_dir: str, dst_dir: str, app_name: str):
os.makedirs(target_dir, exist_ok=True)
for file in files:
if _should_ignore(file):
continue
try:
with open(os.path.join(root, file), "r") as f:
content = f.read()
@@ -81,3 +84,9 @@ def copy_template_files(src_dir: str, dst_dir: str, app_name: str):
except Exception as e:
click.echo(f"Error copying file {file} from {src_dir} to {dst_dir}")
raise e
def _should_ignore(file_or_dir: str):
"""Return True if the given file or directory should be ignored.""" ""
ignore_patterns = [".pyc", "__pycache__"]
return any(pattern in file_or_dir for pattern in ignore_patterns)