refactor: Refactor storage and new serve template (#947)

This commit is contained in:
Fangyin Cheng
2023-12-18 19:30:40 +08:00
committed by GitHub
parent 22d95b444b
commit 511a43b849
63 changed files with 1891 additions and 229 deletions

View File

@@ -1,18 +1,27 @@
from contextlib import contextmanager
from typing import TypeVar, Generic, Any, Optional
from typing import TypeVar, Generic, Any, Optional, Dict, Union, List
from sqlalchemy.orm.session import Session
from dbgpt.util.pagination_utils import PaginationResult
# The entity type
T = TypeVar("T")
# The request schema type
REQ = TypeVar("REQ")
# The response schema type
RES = TypeVar("RES")
from .db_manager import db, DatabaseManager
from .db_manager import db, DatabaseManager, BaseQuery
class BaseDao(Generic[T]):
QUERY_SPEC = Union[REQ, Dict[str, Any]]
class BaseDao(Generic[T, REQ, RES]):
"""The base class for all DAOs.
Examples:
.. code-block:: python
class UserDao(BaseDao[User]):
class UserDao(BaseDao):
def get_user_by_name(self, name: str) -> User:
with self.session() as session:
return session.query(User).filter(User.name == name).first()
@@ -70,3 +79,184 @@ class BaseDao(Generic[T]):
"""
with self._db_manager.session() as session:
yield session
def from_request(self, request: QUERY_SPEC) -> T:
"""Convert a request schema object to an entity object.
Args:
request (REQ): The request schema object or dict for query.
Returns:
T: The entity object.
"""
raise NotImplementedError
def to_request(self, entity: T) -> REQ:
"""Convert an entity object to a request schema object.
Args:
entity (T): The entity object.
Returns:
REQ: The request schema object.
"""
raise NotImplementedError
def from_response(self, response: RES) -> T:
"""Convert a response schema object to an entity object.
Args:
response (RES): The response schema object.
Returns:
T: The entity object.
"""
raise NotImplementedError
def to_response(self, entity: T) -> RES:
"""Convert an entity object to a response schema object.
Args:
entity (T): The entity object.
Returns:
RES: The response schema object.
"""
raise NotImplementedError
def create(self, request: REQ) -> RES:
"""Create an entity object.
Args:
request (REQ): The request schema object.
Returns:
RES: The response schema object.
"""
entry = self.from_request(request)
with self.session() as session:
session.add(entry)
return self.get_one(self.to_request(entry))
def update(self, query_request: QUERY_SPEC, update_request: REQ) -> RES:
"""Update an entity object.
Args:
query_request (REQ): The request schema object or dict for query.
update_request (REQ): The request schema object for update.
Returns:
RES: The response schema object.
"""
with self.session() as session:
query = self._create_query_object(session, query_request)
entry = query.first()
if entry is None:
raise Exception("Invalid request")
for key, value in update_request.dict().items():
setattr(entry, key, value)
session.merge(entry)
return self.get_one(self.to_request(entry))
def delete(self, query_request: QUERY_SPEC) -> None:
"""Delete an entity object.
Args:
query_request (REQ): The request schema object or dict for query.
"""
with self.session() as session:
result_list = self._get_entity_list(session, query_request)
if len(result_list) != 1:
raise ValueError(
f"Delete request should return one result, but got {len(result_list)}"
)
session.delete(result_list[0])
def get_one(self, query_request: QUERY_SPEC) -> Optional[RES]:
"""Get an entity object.
Args:
query_request (REQ): The request schema object or dict for query.
Returns:
Optional[RES]: The response schema object.
"""
with self.session() as session:
query = self._create_query_object(session, query_request)
result = query.first()
if result is None:
return None
return self.to_response(result)
def get_list(self, query_request: QUERY_SPEC) -> List[RES]:
"""Get a list of entity objects.
Args:
query_request (REQ): The request schema object or dict for query.
Returns:
List[RES]: The response schema object.
"""
with self.session() as session:
result_list = self._get_entity_list(session, query_request)
return [self.to_response(item) for item in result_list]
def _get_entity_list(self, session: Session, query_request: QUERY_SPEC) -> List[T]:
"""Get a list of entity objects.
Args:
session (Session): The session object.
query_request (REQ): The request schema object or dict for query.
Returns:
List[RES]: The response schema object.
"""
query = self._create_query_object(session, query_request)
result_list = query.all()
return result_list
def get_list_page(
self, query_request: QUERY_SPEC, page: int, page_size: int
) -> PaginationResult[RES]:
"""Get a page of entity objects.
Args:
query_request (REQ): The request schema object or dict for query.
page (int): The page number.
page_size (int): The page size.
Returns:
PaginationResult: The pagination result.
"""
with self.session() as session:
query = self._create_query_object(session, query_request)
total_count = query.count()
items = query.offset((page - 1) * page_size).limit(page_size)
items = [self.to_response(item) for item in items]
total_pages = (total_count + page_size - 1) // page_size
return PaginationResult(
items=items,
total_count=total_count,
total_pages=total_pages,
page=page,
page_size=page_size,
)
def _create_query_object(
self, session: Session, query_request: QUERY_SPEC
) -> BaseQuery:
"""Create a query object.
Args:
session (Session): The session object.
query_request (QUERY_SPEC): The request schema object or dict for query.
Returns:
BaseQuery: The query object.
"""
model_cls = type(self.from_request(query_request))
query = session.query(model_cls)
query_dict = (
query_request if isinstance(query_request, dict) else query_request.dict()
)
for key, value in query_dict.items():
if value is not None:
query = query.filter(getattr(model_cls, key) == value)
return query

View File

@@ -103,7 +103,8 @@ class SQLAlchemyStorage(StorageInterface[T, BaseModel]):
with self.session() as session:
query = session.query(self._model_class)
for key, value in spec.conditions.items():
query = query.filter(getattr(self._model_class, key) == value)
if value is not None:
query = query.filter(getattr(self._model_class, key) == value)
if spec.limit is not None:
query = query.limit(spec.limit)
if spec.offset is not None:
@@ -124,5 +125,6 @@ class SQLAlchemyStorage(StorageInterface[T, BaseModel]):
with self.session() as session:
query = session.query(self._model_class)
for key, value in spec.conditions.items():
query = query.filter(getattr(self._model_class, key) == value)
if value is not None:
query = query.filter(getattr(self._model_class, key) == value)
return query.count()

View File

@@ -0,0 +1,152 @@
from typing import Type, Optional, Union, Dict, Any
import pytest
from sqlalchemy import Column, Integer, String
from dbgpt._private.pydantic import BaseModel as PydanticBaseModel, Field
from dbgpt.storage.metadata.db_manager import (
DatabaseManager,
PaginationResult,
create_model,
BaseModel,
)
from .._base_dao import BaseDao
class UserRequest(PydanticBaseModel):
name: str = Field(..., description="User name")
age: Optional[int] = Field(default=-1, description="User age")
password: Optional[str] = Field(default="", description="User password")
class UserResponse(PydanticBaseModel):
id: int = Field(..., description="User id")
name: str = Field(..., description="User name")
age: Optional[int] = Field(default=-1, description="User age")
@pytest.fixture
def db():
db = DatabaseManager()
db.init_db("sqlite:///:memory:")
return db
@pytest.fixture
def Model(db):
return create_model(db)
@pytest.fixture
def User(Model):
class User(Model):
__tablename__ = "user"
id = Column(Integer, primary_key=True)
name = Column(String(50), unique=True)
age = Column(Integer)
password = Column(String(50))
return User
@pytest.fixture
def user_req():
return UserRequest(name="Edward Snowden", age=30, password="123456")
@pytest.fixture
def user_dao(db, User):
class UserDao(BaseDao[User, UserRequest, UserResponse]):
def from_request(self, request: Union[UserRequest, Dict[str, Any]]) -> User:
if isinstance(request, UserRequest):
return User(**request.dict())
else:
return User(**request)
def to_request(self, entity: User) -> UserRequest:
return UserRequest(
name=entity.name, age=entity.age, password=entity.password
)
def from_response(self, response: UserResponse) -> User:
return User(**response.dict())
def to_response(self, entity: User):
return UserResponse(id=entity.id, name=entity.name, age=entity.age)
db.create_all()
return UserDao(db)
def test_create_user(db: DatabaseManager, User: Type[BaseModel], user_dao, user_req):
user_dao.create(user_req)
with db.session() as session:
user = session.query(User).first()
assert user.name == user_req.name
assert user.age == user_req.age
assert user.password == user_req.password
def test_update_user(db: DatabaseManager, User: Type[BaseModel], user_dao, user_req):
# Create a user
created_user_response = user_dao.create(user_req)
# Update the user
updated_req = UserRequest(name=user_req.name, age=35, password="newpassword")
updated_user = user_dao.update(
query_request={"name": user_req.name}, update_request=updated_req
)
assert updated_user.id == created_user_response.id
assert updated_user.age == 35
# Verify that the user is updated in the database
with db.session() as session:
user = session.query(User).get(created_user_response.id)
assert user.age == 35
def test_get_user(db: DatabaseManager, User: Type[BaseModel], user_dao, user_req):
# Create a user
created_user_response = user_dao.create(user_req)
# Query the user
fetched_user = user_dao.get_one({"name": user_req.name})
assert fetched_user.id == created_user_response.id
assert fetched_user.name == user_req.name
assert fetched_user.age == user_req.age
def test_get_list_user(db: DatabaseManager, User: Type[BaseModel], user_dao):
for i in range(20):
user_dao.create(
UserRequest(
name=f"User {i}", age=i, password="123456" if i % 2 == 0 else "abcdefg"
)
)
# Query the user
fetched_user = user_dao.get_list({"password": "123456"})
assert len(fetched_user) == 10
def test_get_list_page_user(db: DatabaseManager, User: Type[BaseModel], user_dao):
for i in range(20):
user_dao.create(
UserRequest(
name=f"User {i}", age=i, password="123456" if i % 2 == 0 else "abcdefg"
)
)
page_result: PaginationResult = user_dao.get_list_page(
{"password": "123456"}, page=1, page_size=3
)
assert page_result.total_count == 10
assert page_result.total_pages == 4
assert len(page_result.items) == 3
assert page_result.items[0].name == "User 0"
# Test query next page
page_result: PaginationResult = user_dao.get_list_page(
{"password": "123456"}, page=2, page_size=3
)
assert page_result.total_count == 10
assert page_result.total_pages == 4
assert len(page_result.items) == 3
assert page_result.items[0].name == "User 6"