DB-GPT/dbgpt/storage/metadata/tests/test_base_dao.py
2024-04-20 09:41:16 +08:00

178 lines
5.4 KiB
Python

from typing import Any, Dict, Optional, Type, Union
import pytest
from sqlalchemy import Column, Integer, String
from dbgpt._private.pydantic import BaseModel as PydanticBaseModel
from dbgpt._private.pydantic import Field, model_to_dict
from dbgpt.storage.metadata.db_manager import (
BaseModel,
DatabaseManager,
PaginationResult,
create_model,
)
from .._base_dao import BaseDao
# Input model handed to the DAO for create/update calls.
# ``name`` is required; ``age`` falls back to -1 and ``password`` to "".
class UserRequest(PydanticBaseModel):
    name: str = Field(..., description="User name")
    age: Optional[int] = Field(-1, description="User age")
    password: Optional[str] = Field("", description="User password")
# Output model produced by the DAO; unlike ``UserRequest`` it exposes the
# row ``id`` and deliberately has no ``password`` field.
class UserResponse(PydanticBaseModel):
    id: int = Field(..., description="User id")
    name: str = Field(..., description="User name")
    age: Optional[int] = Field(-1, description="User age")
@pytest.fixture
def db():
    """Provide a ``DatabaseManager`` initialised against an in-memory SQLite URL."""
    manager = DatabaseManager()
    manager.init_db("sqlite:///:memory:")
    return manager
@pytest.fixture
def Model(db):
    """Base model class produced by ``create_model`` for the ``db`` fixture.

    The ``User`` fixture subclasses this to declare its mapped table.
    """
    base_model = create_model(db)
    return base_model
@pytest.fixture
def User(Model):
    """ORM entity for the ``user`` table, declared on top of the ``Model`` fixture."""

    class User(Model):
        __tablename__ = "user"

        # Integer primary key; ``name`` carries a UNIQUE constraint.
        id = Column(Integer, primary_key=True)
        name = Column(String(50), unique=True)
        age = Column(Integer)
        password = Column(String(50))

    return User
@pytest.fixture
def user_req():
    """Sample ``UserRequest`` with name, age and password all populated."""
    request = UserRequest(name="Edward Snowden", age=30, password="123456")
    return request
@pytest.fixture
def user_dao(db, User):
    """Build a concrete ``BaseDao`` for the ``User`` entity bound to ``db``.

    ``db.create_all()`` is invoked only after ``User`` has been defined,
    presumably so that its table is part of the schema that gets created.
    """

    class UserDao(BaseDao[User, UserRequest, UserResponse]):
        def from_request(
            self, request: Union[UserRequest, Dict[str, Any]]
        ) -> User:
            """Build a ``User`` entity from a ``UserRequest`` or a plain dict."""
            data = (
                model_to_dict(request)
                if isinstance(request, UserRequest)
                else request
            )
            return User(**data)

        def to_request(self, entity: User) -> UserRequest:
            """Convert a ``User`` entity back into a ``UserRequest``."""
            return UserRequest(
                name=entity.name, age=entity.age, password=entity.password
            )

        def from_response(self, response: UserResponse) -> User:
            """Build a ``User`` entity from a ``UserResponse``."""
            return User(**model_to_dict(response))

        def to_response(self, entity: User) -> UserResponse:
            """Convert a ``User`` entity into a ``UserResponse`` (password dropped)."""
            return UserResponse(id=entity.id, name=entity.name, age=entity.age)

    db.create_all()
    return UserDao(db)
def test_create_user(db: DatabaseManager, User: Type[BaseModel], user_dao, user_req):
    """``create`` persists name, age and password exactly as requested."""
    user_dao.create(user_req)

    with db.session() as session:
        stored = session.query(User).first()
        assert stored.name == user_req.name
        assert stored.age == user_req.age
        assert stored.password == user_req.password
def test_update_user(db: DatabaseManager, User: Type[BaseModel], user_dao, user_req):
    """``update`` matched by name changes the stored age and keeps the row id."""
    created = user_dao.create(user_req)

    new_values = UserRequest(name=user_req.name, age=35, password="newpassword")
    response = user_dao.update(
        query_request={"name": user_req.name}, update_request=new_values
    )
    # The returned response must describe the row created above, with the new age.
    assert response.id == created.id
    assert response.age == 35

    # Re-read the row through a new session to confirm the update was stored.
    with db.session() as session:
        row = session.get(User, created.id)
        assert row.age == 35
def test_update_user_partial(
    db: DatabaseManager, User: Type[BaseModel], user_dao, user_req
):
    """A ``None`` field in the update request leaves the stored value as it was."""
    created = user_dao.create(user_req)

    partial = UserRequest(name=user_req.name, password="newpassword")
    # ``age`` would otherwise default to -1; clear it so it is not written.
    partial.age = None
    response = user_dao.update(
        query_request={"name": user_req.name}, update_request=partial
    )
    assert response.id == created.id
    assert response.age == user_req.age

    # The untouched age survives while the password is overwritten.
    with db.session() as session:
        row = session.get(User, created.id)
        assert row.age == user_req.age
        assert row.password == "newpassword"
def test_get_user(db: DatabaseManager, User: Type[BaseModel], user_dao, user_req):
    """``get_one`` finds a previously created user by name."""
    created = user_dao.create(user_req)

    found = user_dao.get_one({"name": user_req.name})
    assert found.id == created.id
    assert found.name == user_req.name
    assert found.age == user_req.age
def test_get_list_user(db: DatabaseManager, User: Type[BaseModel], user_dao):
    """``get_list`` returns exactly the users whose password matches the filter."""
    # Even-numbered users get password "123456", odd-numbered ones "abcdefg".
    for i in range(20):
        user_dao.create(
            UserRequest(
                name=f"User {i}", age=i, password="123456" if i % 2 == 0 else "abcdefg"
            )
        )

    fetched_users = user_dao.get_list({"password": "123456"})
    assert len(fetched_users) == 10
    # A count check alone would also pass for the wrong 10 rows, so pin their
    # identities too. ``UserResponse`` has no password field; compare names.
    assert {user.name for user in fetched_users} == {
        f"User {i}" for i in range(0, 20, 2)
    }
def test_get_list_page_user(db: DatabaseManager, User: Type[BaseModel], user_dao):
    """``get_list_page`` paginates a filtered query and reports its totals."""
    # Even-numbered users get password "123456", so 10 of the 20 rows match.
    for i in range(20):
        user_dao.create(
            UserRequest(
                name=f"User {i}", age=i, password="123456" if i % 2 == 0 else "abcdefg"
            )
        )
    query = {"password": "123456"}

    # Each page gets its own name: re-annotating one variable is a mypy no-redef error.
    first_page: PaginationResult = user_dao.get_list_page(query, page=1, page_size=3)
    assert first_page.total_count == 10
    assert first_page.total_pages == 4  # ceil(10 / 3)
    assert len(first_page.items) == 3
    assert first_page.items[0].name == "User 0"

    # Test query next page
    second_page: PaginationResult = user_dao.get_list_page(query, page=2, page_size=3)
    assert second_page.total_count == 10
    assert second_page.total_pages == 4
    assert len(second_page.items) == 3
    assert second_page.items[0].name == "User 6"

    # Boundary: the last page is only partially filled (10 - 3 * 3 = 1 item).
    last_page: PaginationResult = user_dao.get_list_page(query, page=4, page_size=3)
    assert last_page.total_count == 10
    assert len(last_page.items) == 1
    assert last_page.items[0].name == "User 18"