refactor: Refactor storage and new serve template (#947)

This commit is contained in:
Fangyin Cheng
2023-12-18 19:30:40 +08:00
committed by GitHub
parent 22d95b444b
commit 511a43b849
63 changed files with 1891 additions and 229 deletions

View File

@@ -1,18 +1,27 @@
from contextlib import contextmanager
from typing import TypeVar, Generic, Any, Optional
from typing import TypeVar, Generic, Any, Optional, Dict, Union, List
from sqlalchemy.orm.session import Session
from dbgpt.util.pagination_utils import PaginationResult
# The entity type
T = TypeVar("T")
# The request schema type
REQ = TypeVar("REQ")
# The response schema type
RES = TypeVar("RES")
from .db_manager import db, DatabaseManager
from .db_manager import db, DatabaseManager, BaseQuery
class BaseDao(Generic[T]):
QUERY_SPEC = Union[REQ, Dict[str, Any]]
class BaseDao(Generic[T, REQ, RES]):
"""The base class for all DAOs.
Examples:
.. code-block:: python
class UserDao(BaseDao[User]):
class UserDao(BaseDao):
def get_user_by_name(self, name: str) -> User:
with self.session() as session:
return session.query(User).filter(User.name == name).first()
@@ -70,3 +79,184 @@ class BaseDao(Generic[T]):
"""
with self._db_manager.session() as session:
yield session
def from_request(self, request: QUERY_SPEC) -> T:
"""Convert a request schema object to an entity object.
Args:
request (REQ): The request schema object or dict for query.
Returns:
T: The entity object.
"""
raise NotImplementedError
def to_request(self, entity: T) -> REQ:
"""Convert an entity object to a request schema object.
Args:
entity (T): The entity object.
Returns:
REQ: The request schema object.
"""
raise NotImplementedError
def from_response(self, response: RES) -> T:
"""Convert a response schema object to an entity object.
Args:
response (RES): The response schema object.
Returns:
T: The entity object.
"""
raise NotImplementedError
def to_response(self, entity: T) -> RES:
"""Convert an entity object to a response schema object.
Args:
entity (T): The entity object.
Returns:
RES: The response schema object.
"""
raise NotImplementedError
def create(self, request: REQ) -> RES:
"""Create an entity object.
Args:
request (REQ): The request schema object.
Returns:
RES: The response schema object.
"""
entry = self.from_request(request)
with self.session() as session:
session.add(entry)
return self.get_one(self.to_request(entry))
def update(self, query_request: QUERY_SPEC, update_request: REQ) -> RES:
"""Update an entity object.
Args:
query_request (REQ): The request schema object or dict for query.
update_request (REQ): The request schema object for update.
Returns:
RES: The response schema object.
"""
with self.session() as session:
query = self._create_query_object(session, query_request)
entry = query.first()
if entry is None:
raise Exception("Invalid request")
for key, value in update_request.dict().items():
setattr(entry, key, value)
session.merge(entry)
return self.get_one(self.to_request(entry))
def delete(self, query_request: QUERY_SPEC) -> None:
"""Delete an entity object.
Args:
query_request (REQ): The request schema object or dict for query.
"""
with self.session() as session:
result_list = self._get_entity_list(session, query_request)
if len(result_list) != 1:
raise ValueError(
f"Delete request should return one result, but got {len(result_list)}"
)
session.delete(result_list[0])
def get_one(self, query_request: QUERY_SPEC) -> Optional[RES]:
"""Get an entity object.
Args:
query_request (REQ): The request schema object or dict for query.
Returns:
Optional[RES]: The response schema object.
"""
with self.session() as session:
query = self._create_query_object(session, query_request)
result = query.first()
if result is None:
return None
return self.to_response(result)
def get_list(self, query_request: QUERY_SPEC) -> List[RES]:
"""Get a list of entity objects.
Args:
query_request (REQ): The request schema object or dict for query.
Returns:
List[RES]: The response schema object.
"""
with self.session() as session:
result_list = self._get_entity_list(session, query_request)
return [self.to_response(item) for item in result_list]
def _get_entity_list(self, session: Session, query_request: QUERY_SPEC) -> List[T]:
"""Get a list of entity objects.
Args:
session (Session): The session object.
query_request (REQ): The request schema object or dict for query.
Returns:
List[RES]: The response schema object.
"""
query = self._create_query_object(session, query_request)
result_list = query.all()
return result_list
def get_list_page(
self, query_request: QUERY_SPEC, page: int, page_size: int
) -> PaginationResult[RES]:
"""Get a page of entity objects.
Args:
query_request (REQ): The request schema object or dict for query.
page (int): The page number.
page_size (int): The page size.
Returns:
PaginationResult: The pagination result.
"""
with self.session() as session:
query = self._create_query_object(session, query_request)
total_count = query.count()
items = query.offset((page - 1) * page_size).limit(page_size)
items = [self.to_response(item) for item in items]
total_pages = (total_count + page_size - 1) // page_size
return PaginationResult(
items=items,
total_count=total_count,
total_pages=total_pages,
page=page,
page_size=page_size,
)
def _create_query_object(
self, session: Session, query_request: QUERY_SPEC
) -> BaseQuery:
"""Create a query object.
Args:
session (Session): The session object.
query_request (QUERY_SPEC): The request schema object or dict for query.
Returns:
BaseQuery: The query object.
"""
model_cls = type(self.from_request(query_request))
query = session.query(model_cls)
query_dict = (
query_request if isinstance(query_request, dict) else query_request.dict()
)
for key, value in query_dict.items():
if value is not None:
query = query.filter(getattr(model_cls, key) == value)
return query