mirror of
https://github.com/csunny/DB-GPT.git
synced 2025-08-09 20:28:07 +00:00
178 lines
5.4 KiB
Python
178 lines
5.4 KiB
Python
from typing import Any, Dict, Optional, Type, Union
|
|
|
|
import pytest
|
|
from sqlalchemy import Column, Integer, String
|
|
|
|
from dbgpt._private.pydantic import BaseModel as PydanticBaseModel
|
|
from dbgpt._private.pydantic import Field, model_to_dict
|
|
from dbgpt.storage.metadata.db_manager import (
|
|
BaseModel,
|
|
DatabaseManager,
|
|
PaginationResult,
|
|
create_model,
|
|
)
|
|
|
|
from .._base_dao import BaseDao
|
|
|
|
|
|
class UserRequest(PydanticBaseModel):
    """Input schema used when creating or updating a user via the DAO."""

    # ``name`` is required; the remaining fields fall back to sentinel defaults.
    name: str = Field(..., description="User name")
    age: Union[int, None] = Field(-1, description="User age")
    password: Union[str, None] = Field("", description="User password")
|
|
|
|
|
|
class UserResponse(PydanticBaseModel):
    """Output schema returned by the DAO; it has no ``password`` field."""

    id: int = Field(..., description="User id")
    name: str = Field(..., description="User name")
    age: Union[int, None] = Field(-1, description="User age")
|
|
|
|
|
|
@pytest.fixture
def db():
    """Provide a new ``DatabaseManager`` initialised against in-memory SQLite."""
    manager = DatabaseManager()
    manager.init_db("sqlite:///:memory:")
    return manager
|
|
|
|
|
|
@pytest.fixture
def Model(db):
    """Return the model base class produced by ``create_model`` for ``db``.

    The ``User`` fixture subclasses it to declare the test table.
    """
    base_cls = create_model(db)
    return base_cls
|
|
|
|
|
|
@pytest.fixture
def User(Model):
    """Declare and return the ``user`` table mapping on the per-test ``Model``.

    The columns are declared in the same order as in the schema they mirror:
    ``id``, ``name``, ``age``, ``password``.
    """

    class User(Model):
        __tablename__ = "user"

        id = Column(Integer, primary_key=True)
        name = Column(String(length=50), unique=True)
        age = Column(Integer)
        password = Column(String(length=50))

    mapped_cls = User
    return mapped_cls
|
|
|
|
|
|
@pytest.fixture
def user_req():
    """Return the sample ``UserRequest`` shared by the single-user tests."""
    fields = dict(name="Edward Snowden", age=30, password="123456")
    return UserRequest(**fields)
|
|
|
|
|
|
@pytest.fixture
def user_dao(db, User):
    """Build a concrete ``BaseDao`` for ``User`` after creating the tables.

    The DAO maps between the SQLAlchemy ``User`` entity and the pydantic
    ``UserRequest`` / ``UserResponse`` schemas.
    """

    class UserDao(BaseDao[User, UserRequest, UserResponse]):
        def from_request(self, request: Union[UserRequest, Dict[str, Any]]) -> User:
            """Build a ``User`` entity from a request model or a plain dict."""
            if isinstance(request, UserRequest):
                return User(**model_to_dict(request))
            return User(**request)

        def to_request(self, entity: User) -> UserRequest:
            """Convert a ``User`` entity back into a ``UserRequest``."""
            return UserRequest(
                name=entity.name, age=entity.age, password=entity.password
            )

        def from_response(self, response: UserResponse) -> User:
            """Build a ``User`` entity from a ``UserResponse``."""
            return User(**model_to_dict(response))

        def to_response(self, entity: User) -> UserResponse:
            """Convert a ``User`` entity into a ``UserResponse``.

            ``UserResponse`` has no password field, so the password is not
            copied.
            """
            return UserResponse(id=entity.id, name=entity.name, age=entity.age)

    # Create the declared tables before the DAO is handed to the tests.
    db.create_all()
    return UserDao(db)
|
|
|
|
|
|
def test_create_user(db: DatabaseManager, User: Type[BaseModel], user_dao, user_req):
    """Check that creating a user through the DAO stores every request field."""
    user_dao.create(user_req)

    with db.session() as session:
        stored = session.query(User).first()
        for attr in ("name", "age", "password"):
            assert getattr(stored, attr) == getattr(user_req, attr)
|
|
|
|
|
|
def test_update_user(db: DatabaseManager, User: Type[BaseModel], user_dao, user_req):
    """Check that a full update rewrites the fields while keeping the row id."""
    # Create a user
    created_user_response = user_dao.create(user_req)

    # Update the user
    updated_req = UserRequest(name=user_req.name, age=35, password="newpassword")
    updated_user = user_dao.update(
        query_request={"name": user_req.name}, update_request=updated_req
    )
    assert updated_user.id == created_user_response.id
    assert updated_user.name == user_req.name
    assert updated_user.age == 35

    # Verify that the user is updated in the database. The password is
    # checked here because UserResponse does not expose it.
    with db.session() as session:
        user = session.get(User, created_user_response.id)
        assert user.age == 35
        assert user.password == "newpassword"
|
|
|
|
|
|
def test_update_user_partial(
    db: DatabaseManager, User: Type[BaseModel], user_dao, user_req
):
    """Check that an update with ``age`` set to ``None`` keeps the stored age."""
    created = user_dao.create(user_req)

    # Clear ``age`` explicitly so that only ``password`` is expected to change.
    partial_req = UserRequest(name=user_req.name, password="newpassword")
    partial_req.age = None
    result = user_dao.update(
        query_request={"name": user_req.name}, update_request=partial_req
    )
    assert result.id == created.id
    assert result.age == user_req.age

    # The stored row keeps its original age but has the new password.
    with db.session() as session:
        stored = session.get(User, created.id)
        assert stored.age == user_req.age
        assert stored.password == "newpassword"
|
|
|
|
|
|
def test_get_user(db: DatabaseManager, User: Type[BaseModel], user_dao, user_req):
    """Check that ``get_one`` finds a previously created user by name."""
    created = user_dao.create(user_req)

    found = user_dao.get_one({"name": user_req.name})
    assert (found.id, found.name, found.age) == (
        created.id,
        user_req.name,
        user_req.age,
    )
|
|
|
|
|
|
def test_get_list_user(db: DatabaseManager, User: Type[BaseModel], user_dao):
    """Check that ``get_list`` returns exactly the users matching the filter."""
    # Even-numbered users get password "123456"; odd-numbered ones "abcdefg".
    for i in range(20):
        user_dao.create(
            UserRequest(
                name=f"User {i}", age=i, password="123456" if i % 2 == 0 else "abcdefg"
            )
        )
    # Query the users
    fetched_users = user_dao.get_list({"password": "123456"})
    assert len(fetched_users) == 10
    # Check which users matched, not just how many (order-independent).
    assert {user.name for user in fetched_users} == {
        f"User {i}" for i in range(0, 20, 2)
    }
|
|
|
|
|
|
def test_get_list_page_user(db: DatabaseManager, User: Type[BaseModel], user_dao):
    """Check that ``get_list_page`` pages filtered users in creation order."""
    # Even-numbered users get password "123456"; odd-numbered ones "abcdefg".
    for i in range(20):
        user_dao.create(
            UserRequest(
                name=f"User {i}", age=i, password="123456" if i % 2 == 0 else "abcdefg"
            )
        )
    # 10 matching users split into pages of 3 gives 4 pages.
    page_result: PaginationResult = user_dao.get_list_page(
        {"password": "123456"}, page=1, page_size=3
    )
    assert page_result.total_count == 10
    assert page_result.total_pages == 4
    assert len(page_result.items) == 3
    assert [item.name for item in page_result.items] == [
        "User 0",
        "User 2",
        "User 4",
    ]

    # Test query next page. Plain re-assignment: annotating the same name a
    # second time is reported by mypy as a redefinition.
    page_result = user_dao.get_list_page(
        {"password": "123456"}, page=2, page_size=3
    )
    assert page_result.total_count == 10
    assert page_result.total_pages == 4
    assert len(page_result.items) == 3
    assert [item.name for item in page_result.items] == [
        "User 6",
        "User 8",
        "User 10",
    ]
|