mirror of
https://github.com/csunny/DB-GPT.git
synced 2025-09-13 13:10:29 +00:00
feat(core): Add API authentication for serve template (#950)
This commit is contained in:
0
dbgpt/serve/prompt/tests/__init__.py
Normal file
0
dbgpt/serve/prompt/tests/__init__.py
Normal file
176
dbgpt/serve/prompt/tests/test_endpoints.py
Normal file
176
dbgpt/serve/prompt/tests/test_endpoints.py
Normal file
@@ -0,0 +1,176 @@
|
||||
import pytest
|
||||
from httpx import AsyncClient
|
||||
|
||||
from fastapi import FastAPI
|
||||
from dbgpt.component import SystemApp
|
||||
from dbgpt.storage.metadata import db
|
||||
from dbgpt.util import PaginationResult
|
||||
from ..config import SERVE_CONFIG_KEY_PREFIX
|
||||
from ..api.endpoints import router, init_endpoints
|
||||
from ..api.schemas import ServeRequest, ServerResponse
|
||||
|
||||
from dbgpt.serve.core.tests.conftest import client, asystem_app
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def setup_and_teardown():
|
||||
db.init_db("sqlite:///:memory:")
|
||||
db.create_all()
|
||||
|
||||
yield
|
||||
|
||||
|
||||
def client_init_caller(app: FastAPI, system_app: SystemApp):
|
||||
app.include_router(router)
|
||||
init_endpoints(system_app)
|
||||
|
||||
|
||||
async def _create_and_validate(
|
||||
client: AsyncClient, sys_code: str, content: str, expect_id: int = 1, **kwargs
|
||||
):
|
||||
req_json = {"sys_code": sys_code, "content": content}
|
||||
req_json.update(kwargs)
|
||||
response = await client.post("/add", json=req_json)
|
||||
assert response.status_code == 200
|
||||
json_res = response.json()
|
||||
assert "success" in json_res and json_res["success"]
|
||||
assert "data" in json_res and json_res["data"]
|
||||
data = json_res["data"]
|
||||
res_obj = ServerResponse(**data)
|
||||
assert res_obj.id == expect_id
|
||||
assert res_obj.sys_code == sys_code
|
||||
assert res_obj.content == content
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize(
|
||||
"client, asystem_app, has_auth",
|
||||
[
|
||||
(
|
||||
{
|
||||
"app_caller": client_init_caller,
|
||||
"client_api_key": "test_token1",
|
||||
},
|
||||
{
|
||||
"app_config": {
|
||||
f"{SERVE_CONFIG_KEY_PREFIX}api_keys": "test_token1,test_token2"
|
||||
}
|
||||
},
|
||||
True,
|
||||
),
|
||||
(
|
||||
{
|
||||
"app_caller": client_init_caller,
|
||||
"client_api_key": "error_token",
|
||||
},
|
||||
{
|
||||
"app_config": {
|
||||
f"{SERVE_CONFIG_KEY_PREFIX}api_keys": "test_token1,test_token2"
|
||||
}
|
||||
},
|
||||
False,
|
||||
),
|
||||
],
|
||||
indirect=["client", "asystem_app"],
|
||||
)
|
||||
async def test_api_health(client: AsyncClient, asystem_app, has_auth: bool):
|
||||
response = await client.get("/test_auth")
|
||||
if has_auth:
|
||||
assert response.status_code == 200
|
||||
assert response.json() == {"status": "ok"}
|
||||
else:
|
||||
assert response.status_code == 401
|
||||
assert response.json() == {
|
||||
"detail": {
|
||||
"error": {
|
||||
"message": "",
|
||||
"type": "invalid_request_error",
|
||||
"param": None,
|
||||
"code": "invalid_api_key",
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize(
|
||||
"client", [{"app_caller": client_init_caller}], indirect=["client"]
|
||||
)
|
||||
async def test_api_auth(client: AsyncClient):
|
||||
response = await client.get("/health")
|
||||
assert response.status_code == 200
|
||||
assert response.json() == {"status": "ok"}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize(
|
||||
"client", [{"app_caller": client_init_caller}], indirect=["client"]
|
||||
)
|
||||
async def test_api_create(client: AsyncClient):
|
||||
await _create_and_validate(client, "test", "test")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize(
|
||||
"client", [{"app_caller": client_init_caller}], indirect=["client"]
|
||||
)
|
||||
async def test_api_update(client: AsyncClient):
|
||||
await _create_and_validate(client, "test", "test")
|
||||
|
||||
response = await client.post("/update", json={"id": 1, "content": "test2"})
|
||||
assert response.status_code == 200
|
||||
json_res = response.json()
|
||||
assert "success" in json_res and json_res["success"]
|
||||
assert "data" in json_res and json_res["data"]
|
||||
data = json_res["data"]
|
||||
res_obj = ServerResponse(**data)
|
||||
assert res_obj.id == 1
|
||||
assert res_obj.sys_code == "test"
|
||||
assert res_obj.content == "test2"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize(
|
||||
"client", [{"app_caller": client_init_caller}], indirect=["client"]
|
||||
)
|
||||
async def test_api_query(client: AsyncClient):
|
||||
for i in range(10):
|
||||
await _create_and_validate(
|
||||
client, "test", f"test{i}", expect_id=i + 1, prompt_name=f"prompt_name_{i}"
|
||||
)
|
||||
response = await client.post("/list", json={"sys_code": "test"})
|
||||
assert response.status_code == 200
|
||||
json_res = response.json()
|
||||
assert "success" in json_res and json_res["success"]
|
||||
assert "data" in json_res and json_res["data"]
|
||||
data = json_res["data"]
|
||||
assert len(data) == 10
|
||||
res_obj = ServerResponse(**data[0])
|
||||
assert res_obj.id == 1
|
||||
assert res_obj.sys_code == "test"
|
||||
assert res_obj.content == "test0"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize(
|
||||
"client", [{"app_caller": client_init_caller}], indirect=["client"]
|
||||
)
|
||||
async def test_api_query_by_page(client: AsyncClient):
|
||||
for i in range(10):
|
||||
await _create_and_validate(
|
||||
client, "test", f"test{i}", expect_id=i + 1, prompt_name=f"prompt_name_{i}"
|
||||
)
|
||||
response = await client.post(
|
||||
"/query_page", params={"page": 1, "page_size": 5}, json={"sys_code": "test"}
|
||||
)
|
||||
assert response.status_code == 200
|
||||
json_res = response.json()
|
||||
assert "success" in json_res and json_res["success"]
|
||||
assert "data" in json_res and json_res["data"]
|
||||
data = json_res["data"]
|
||||
page_result: PaginationResult = PaginationResult(**data)
|
||||
assert page_result.total_count == 10
|
||||
assert page_result.total_pages == 2
|
||||
assert page_result.page == 1
|
||||
assert page_result.page_size == 5
|
||||
assert len(page_result.items) == 5
|
257
dbgpt/serve/prompt/tests/test_models.py
Normal file
257
dbgpt/serve/prompt/tests/test_models.py
Normal file
@@ -0,0 +1,257 @@
|
||||
from typing import List
|
||||
import pytest
|
||||
from dbgpt.storage.metadata import db
|
||||
from ..config import ServeConfig
|
||||
from ..api.schemas import ServeRequest, ServerResponse
|
||||
from ..models.models import ServeEntity, ServeDao
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def setup_and_teardown():
|
||||
db.init_db("sqlite:///:memory:")
|
||||
db.create_all()
|
||||
|
||||
yield
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def server_config():
|
||||
return ServeConfig()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def dao(server_config):
|
||||
return ServeDao(server_config)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def default_entity_dict():
|
||||
return {
|
||||
"chat_scene": "chat_data",
|
||||
"sub_chat_scene": "excel",
|
||||
"prompt_type": "common",
|
||||
"prompt_name": "my_prompt_1",
|
||||
"content": "Write a qsort function in python.",
|
||||
"user_name": "zhangsan",
|
||||
"sys_code": "dbgpt",
|
||||
}
|
||||
|
||||
|
||||
def test_table_exist():
|
||||
assert ServeEntity.__tablename__ in db.metadata.tables
|
||||
|
||||
|
||||
def test_entity_create(default_entity_dict):
|
||||
entity: ServeEntity = ServeEntity.create(**default_entity_dict)
|
||||
with db.session() as session:
|
||||
db_entity: ServeEntity = session.query(ServeEntity).get(entity.id)
|
||||
assert db_entity.id == entity.id
|
||||
assert db_entity.chat_scene == "chat_data"
|
||||
assert db_entity.sub_chat_scene == "excel"
|
||||
assert db_entity.prompt_type == "common"
|
||||
assert db_entity.prompt_name == "my_prompt_1"
|
||||
assert db_entity.content == "Write a qsort function in python."
|
||||
assert db_entity.user_name == "zhangsan"
|
||||
assert db_entity.sys_code == "dbgpt"
|
||||
assert db_entity.gmt_created is not None
|
||||
assert db_entity.gmt_modified is not None
|
||||
|
||||
|
||||
def test_entity_unique_key(default_entity_dict):
|
||||
ServeEntity.create(**default_entity_dict)
|
||||
with pytest.raises(Exception):
|
||||
ServeEntity.create(**{"prompt_name": "my_prompt_1", "sys_code": "dbgpt"})
|
||||
|
||||
|
||||
def test_entity_get(default_entity_dict):
|
||||
entity: ServeEntity = ServeEntity.create(**default_entity_dict)
|
||||
db_entity: ServeEntity = ServeEntity.get(entity.id)
|
||||
assert db_entity.id == entity.id
|
||||
assert db_entity.chat_scene == "chat_data"
|
||||
assert db_entity.sub_chat_scene == "excel"
|
||||
assert db_entity.prompt_type == "common"
|
||||
assert db_entity.prompt_name == "my_prompt_1"
|
||||
assert db_entity.content == "Write a qsort function in python."
|
||||
assert db_entity.user_name == "zhangsan"
|
||||
assert db_entity.sys_code == "dbgpt"
|
||||
assert db_entity.gmt_created is not None
|
||||
assert db_entity.gmt_modified is not None
|
||||
|
||||
|
||||
def test_entity_update(default_entity_dict):
|
||||
entity: ServeEntity = ServeEntity.create(**default_entity_dict)
|
||||
entity.update(prompt_name="my_prompt_2")
|
||||
db_entity: ServeEntity = ServeEntity.get(entity.id)
|
||||
assert db_entity.id == entity.id
|
||||
assert db_entity.chat_scene == "chat_data"
|
||||
assert db_entity.sub_chat_scene == "excel"
|
||||
assert db_entity.prompt_type == "common"
|
||||
assert db_entity.prompt_name == "my_prompt_2"
|
||||
assert db_entity.content == "Write a qsort function in python."
|
||||
assert db_entity.user_name == "zhangsan"
|
||||
assert db_entity.sys_code == "dbgpt"
|
||||
assert db_entity.gmt_created is not None
|
||||
assert db_entity.gmt_modified is not None
|
||||
|
||||
|
||||
def test_entity_delete(default_entity_dict):
|
||||
entity: ServeEntity = ServeEntity.create(**default_entity_dict)
|
||||
entity.delete()
|
||||
db_entity: ServeEntity = ServeEntity.get(entity.id)
|
||||
assert db_entity is None
|
||||
|
||||
|
||||
def test_entity_all():
|
||||
for i in range(10):
|
||||
ServeEntity.create(
|
||||
chat_scene="chat_data",
|
||||
sub_chat_scene="excel",
|
||||
prompt_type="common",
|
||||
prompt_name=f"my_prompt_{i}",
|
||||
content="Write a qsort function in python.",
|
||||
user_name="zhangsan",
|
||||
sys_code="dbgpt",
|
||||
)
|
||||
entities = ServeEntity.all()
|
||||
assert len(entities) == 10
|
||||
for entity in entities:
|
||||
assert entity.chat_scene == "chat_data"
|
||||
assert entity.sub_chat_scene == "excel"
|
||||
assert entity.prompt_type == "common"
|
||||
assert entity.content == "Write a qsort function in python."
|
||||
assert entity.user_name == "zhangsan"
|
||||
assert entity.sys_code == "dbgpt"
|
||||
assert entity.gmt_created is not None
|
||||
assert entity.gmt_modified is not None
|
||||
|
||||
|
||||
def test_dao_create(dao, default_entity_dict):
|
||||
req = ServeRequest(**default_entity_dict)
|
||||
res: ServerResponse = dao.create(req)
|
||||
assert res is not None
|
||||
assert res.id == 1
|
||||
assert res.chat_scene == "chat_data"
|
||||
assert res.sub_chat_scene == "excel"
|
||||
assert res.prompt_type == "common"
|
||||
assert res.prompt_name == "my_prompt_1"
|
||||
assert res.content == "Write a qsort function in python."
|
||||
assert res.user_name == "zhangsan"
|
||||
assert res.sys_code == "dbgpt"
|
||||
|
||||
|
||||
def test_dao_get_one(dao, default_entity_dict):
|
||||
req = ServeRequest(**default_entity_dict)
|
||||
res: ServerResponse = dao.create(req)
|
||||
res: ServerResponse = dao.get_one(
|
||||
{"prompt_name": "my_prompt_1", "sys_code": "dbgpt"}
|
||||
)
|
||||
assert res is not None
|
||||
assert res.id == 1
|
||||
assert res.chat_scene == "chat_data"
|
||||
assert res.sub_chat_scene == "excel"
|
||||
assert res.prompt_type == "common"
|
||||
assert res.prompt_name == "my_prompt_1"
|
||||
assert res.content == "Write a qsort function in python."
|
||||
assert res.user_name == "zhangsan"
|
||||
assert res.sys_code == "dbgpt"
|
||||
|
||||
|
||||
def test_get_dao_get_list(dao):
|
||||
for i in range(10):
|
||||
dao.create(
|
||||
ServeRequest(
|
||||
chat_scene="chat_data",
|
||||
sub_chat_scene="excel",
|
||||
prompt_type="common",
|
||||
prompt_name=f"my_prompt_{i}",
|
||||
content="Write a qsort function in python.",
|
||||
user_name="zhangsan" if i % 2 == 0 else "lisi",
|
||||
sys_code="dbgpt",
|
||||
)
|
||||
)
|
||||
res: List[ServerResponse] = dao.get_list({"sys_code": "dbgpt"})
|
||||
assert len(res) == 10
|
||||
for i, r in enumerate(res):
|
||||
assert r.id == i + 1
|
||||
assert r.chat_scene == "chat_data"
|
||||
assert r.sub_chat_scene == "excel"
|
||||
assert r.prompt_type == "common"
|
||||
assert r.prompt_name == f"my_prompt_{i}"
|
||||
assert r.content == "Write a qsort function in python."
|
||||
assert r.user_name == "zhangsan" if i % 2 == 0 else "lisi"
|
||||
assert r.sys_code == "dbgpt"
|
||||
|
||||
half_res: List[ServerResponse] = dao.get_list({"user_name": "zhangsan"})
|
||||
assert len(half_res) == 5
|
||||
|
||||
|
||||
def test_dao_update(dao, default_entity_dict):
|
||||
req = ServeRequest(**default_entity_dict)
|
||||
res: ServerResponse = dao.create(req)
|
||||
res: ServerResponse = dao.update(
|
||||
{"prompt_name": "my_prompt_1", "sys_code": "dbgpt"},
|
||||
ServeRequest(prompt_name="my_prompt_2"),
|
||||
)
|
||||
assert res is not None
|
||||
assert res.id == 1
|
||||
assert res.chat_scene == "chat_data"
|
||||
assert res.sub_chat_scene == "excel"
|
||||
assert res.prompt_type == "common"
|
||||
assert res.prompt_name == "my_prompt_2"
|
||||
assert res.content == "Write a qsort function in python."
|
||||
assert res.user_name == "zhangsan"
|
||||
assert res.sys_code == "dbgpt"
|
||||
|
||||
|
||||
def test_dao_delete(dao, default_entity_dict):
|
||||
req = ServeRequest(**default_entity_dict)
|
||||
res: ServerResponse = dao.create(req)
|
||||
dao.delete({"prompt_name": "my_prompt_1", "sys_code": "dbgpt"})
|
||||
res: ServerResponse = dao.get_one(
|
||||
{"prompt_name": "my_prompt_1", "sys_code": "dbgpt"}
|
||||
)
|
||||
assert res is None
|
||||
|
||||
|
||||
def test_dao_get_list_page(dao):
|
||||
for i in range(20):
|
||||
dao.create(
|
||||
ServeRequest(
|
||||
chat_scene="chat_data",
|
||||
sub_chat_scene="excel",
|
||||
prompt_type="common",
|
||||
prompt_name=f"my_prompt_{i}",
|
||||
content="Write a qsort function in python.",
|
||||
user_name="zhangsan" if i % 2 == 0 else "lisi",
|
||||
sys_code="dbgpt",
|
||||
)
|
||||
)
|
||||
res = dao.get_list_page({"sys_code": "dbgpt"}, page=1, page_size=8)
|
||||
assert res.total_count == 20
|
||||
assert res.total_pages == 3
|
||||
assert res.page == 1
|
||||
assert res.page_size == 8
|
||||
assert len(res.items) == 8
|
||||
for i, r in enumerate(res.items):
|
||||
assert r.id == i + 1
|
||||
assert r.chat_scene == "chat_data"
|
||||
assert r.sub_chat_scene == "excel"
|
||||
assert r.prompt_type == "common"
|
||||
assert r.prompt_name == f"my_prompt_{i}"
|
||||
assert r.content == "Write a qsort function in python."
|
||||
assert r.user_name == "zhangsan" if i % 2 == 0 else "lisi"
|
||||
assert r.sys_code == "dbgpt"
|
||||
|
||||
res_half = dao.get_list_page({"user_name": "zhangsan"}, page=2, page_size=8)
|
||||
assert res_half.total_count == 10
|
||||
assert res_half.total_pages == 2
|
||||
assert res_half.page == 2
|
||||
assert res_half.page_size == 8
|
||||
assert len(res_half.items) == 2
|
||||
for i, r in enumerate(res_half.items):
|
||||
assert r.chat_scene == "chat_data"
|
||||
assert r.sub_chat_scene == "excel"
|
||||
assert r.prompt_type == "common"
|
||||
assert r.content == "Write a qsort function in python."
|
||||
assert r.user_name == "zhangsan"
|
||||
assert r.sys_code == "dbgpt"
|
154
dbgpt/serve/prompt/tests/test_service.py
Normal file
154
dbgpt/serve/prompt/tests/test_service.py
Normal file
@@ -0,0 +1,154 @@
|
||||
from typing import List
|
||||
import pytest
|
||||
from dbgpt.component import SystemApp
|
||||
from dbgpt.storage.metadata import db
|
||||
from dbgpt.serve.core.tests.conftest import system_app
|
||||
|
||||
from ..models.models import ServeEntity
|
||||
from ..api.schemas import ServeRequest, ServerResponse
|
||||
from ..service.service import Service
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def setup_and_teardown():
|
||||
db.init_db("sqlite:///:memory:")
|
||||
db.create_all()
|
||||
yield
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def service(system_app: SystemApp):
|
||||
instance = Service(system_app)
|
||||
instance.init_app(system_app)
|
||||
return instance
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def default_entity_dict():
|
||||
return {
|
||||
"chat_scene": "chat_data",
|
||||
"sub_chat_scene": "excel",
|
||||
"prompt_type": "common",
|
||||
"prompt_name": "my_prompt_1",
|
||||
"content": "Write a qsort function in python.",
|
||||
"user_name": "zhangsan",
|
||||
"sys_code": "dbgpt",
|
||||
}
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"system_app",
|
||||
[{"app_config": {"DEBUG": True, "dbgpt.serve.test_key": "hello"}}],
|
||||
indirect=True,
|
||||
)
|
||||
def test_config_exists(service: Service):
|
||||
system_app: SystemApp = service._system_app
|
||||
assert system_app.config.get("DEBUG") is True
|
||||
assert system_app.config.get("dbgpt.serve.test_key") == "hello"
|
||||
assert service.config is not None
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"system_app",
|
||||
[
|
||||
{
|
||||
"app_config": {
|
||||
"DEBUG": True,
|
||||
"dbgpt.serve.prompt.default_user": "dbgpt",
|
||||
"dbgpt.serve.prompt.default_sys_code": "dbgpt",
|
||||
}
|
||||
}
|
||||
],
|
||||
indirect=True,
|
||||
)
|
||||
def test_config_default_user(service: Service):
|
||||
system_app: SystemApp = service._system_app
|
||||
assert system_app.config.get("DEBUG") is True
|
||||
assert system_app.config.get("dbgpt.serve.prompt.default_user") == "dbgpt"
|
||||
assert service.config is not None
|
||||
assert service.config.default_user == "dbgpt"
|
||||
assert service.config.default_sys_code == "dbgpt"
|
||||
|
||||
|
||||
def test_service_create(service: Service, default_entity_dict):
|
||||
entity: ServerResponse = service.create(ServeRequest(**default_entity_dict))
|
||||
with db.session() as session:
|
||||
db_entity: ServeEntity = session.query(ServeEntity).get(entity.id)
|
||||
assert db_entity.id == entity.id
|
||||
assert db_entity.chat_scene == "chat_data"
|
||||
assert db_entity.sub_chat_scene == "excel"
|
||||
assert db_entity.prompt_type == "common"
|
||||
assert db_entity.prompt_name == "my_prompt_1"
|
||||
assert db_entity.content == "Write a qsort function in python."
|
||||
assert db_entity.user_name == "zhangsan"
|
||||
assert db_entity.sys_code == "dbgpt"
|
||||
assert db_entity.gmt_created is not None
|
||||
assert db_entity.gmt_modified is not None
|
||||
|
||||
|
||||
def test_service_update(service: Service, default_entity_dict):
|
||||
service.create(ServeRequest(**default_entity_dict))
|
||||
entity: ServerResponse = service.update(ServeRequest(**default_entity_dict))
|
||||
with db.session() as session:
|
||||
db_entity: ServeEntity = session.query(ServeEntity).get(entity.id)
|
||||
assert db_entity.id == entity.id
|
||||
assert db_entity.chat_scene == "chat_data"
|
||||
assert db_entity.sub_chat_scene == "excel"
|
||||
assert db_entity.prompt_type == "common"
|
||||
assert db_entity.prompt_name == "my_prompt_1"
|
||||
assert db_entity.content == "Write a qsort function in python."
|
||||
assert db_entity.user_name == "zhangsan"
|
||||
assert db_entity.sys_code == "dbgpt"
|
||||
assert db_entity.gmt_created is not None
|
||||
assert db_entity.gmt_modified is not None
|
||||
|
||||
|
||||
def test_service_get(service: Service, default_entity_dict):
|
||||
service.create(ServeRequest(**default_entity_dict))
|
||||
entity: ServerResponse = service.get(ServeRequest(**default_entity_dict))
|
||||
with db.session() as session:
|
||||
db_entity: ServeEntity = session.query(ServeEntity).get(entity.id)
|
||||
assert db_entity.id == entity.id
|
||||
assert db_entity.chat_scene == "chat_data"
|
||||
assert db_entity.sub_chat_scene == "excel"
|
||||
assert db_entity.prompt_type == "common"
|
||||
assert db_entity.prompt_name == "my_prompt_1"
|
||||
assert db_entity.content == "Write a qsort function in python."
|
||||
assert db_entity.user_name == "zhangsan"
|
||||
assert db_entity.sys_code == "dbgpt"
|
||||
assert db_entity.gmt_created is not None
|
||||
assert db_entity.gmt_modified is not None
|
||||
|
||||
|
||||
def test_service_delete(service: Service, default_entity_dict):
|
||||
service.create(ServeRequest(**default_entity_dict))
|
||||
service.delete(ServeRequest(**default_entity_dict))
|
||||
entity: ServerResponse = service.get(ServeRequest(**default_entity_dict))
|
||||
assert entity is None
|
||||
|
||||
|
||||
def test_service_get_list(service: Service):
|
||||
for i in range(3):
|
||||
service.create(
|
||||
ServeRequest(**{"prompt_name": f"prompt_{i}", "sys_code": "dbgpt"})
|
||||
)
|
||||
entities: List[ServerResponse] = service.get_list(ServeRequest(sys_code="dbgpt"))
|
||||
assert len(entities) == 3
|
||||
for i, entity in enumerate(entities):
|
||||
assert entity.sys_code == "dbgpt"
|
||||
assert entity.prompt_name == f"prompt_{i}"
|
||||
|
||||
|
||||
def test_service_get_list_by_page(service: Service):
|
||||
for i in range(3):
|
||||
service.create(
|
||||
ServeRequest(**{"prompt_name": f"prompt_{i}", "sys_code": "dbgpt"})
|
||||
)
|
||||
res = service.get_list_by_page(ServeRequest(sys_code="dbgpt"), page=1, page_size=2)
|
||||
assert res is not None
|
||||
assert res.total_count == 3
|
||||
assert res.total_pages == 2
|
||||
assert len(res.items) == 2
|
||||
for i, entity in enumerate(res.items):
|
||||
assert entity.sys_code == "dbgpt"
|
||||
assert entity.prompt_name == f"prompt_{i}"
|
Reference in New Issue
Block a user