mirror of
https://github.com/csunny/DB-GPT.git
synced 2025-09-05 02:51:07 +00:00
refactor: Refactor storage and new serve template (#947)
This commit is contained in:
@@ -1,18 +1,27 @@
|
||||
from contextlib import contextmanager
|
||||
from typing import TypeVar, Generic, Any, Optional
|
||||
from typing import TypeVar, Generic, Any, Optional, Dict, Union, List
|
||||
from sqlalchemy.orm.session import Session
|
||||
from dbgpt.util.pagination_utils import PaginationResult
|
||||
|
||||
# The entity type
|
||||
T = TypeVar("T")
|
||||
# The request schema type
|
||||
REQ = TypeVar("REQ")
|
||||
# The response schema type
|
||||
RES = TypeVar("RES")
|
||||
|
||||
from .db_manager import db, DatabaseManager
|
||||
from .db_manager import db, DatabaseManager, BaseQuery
|
||||
|
||||
|
||||
class BaseDao(Generic[T]):
|
||||
QUERY_SPEC = Union[REQ, Dict[str, Any]]
|
||||
|
||||
|
||||
class BaseDao(Generic[T, REQ, RES]):
|
||||
"""The base class for all DAOs.
|
||||
|
||||
Examples:
|
||||
.. code-block:: python
|
||||
class UserDao(BaseDao[User]):
|
||||
class UserDao(BaseDao):
|
||||
def get_user_by_name(self, name: str) -> User:
|
||||
with self.session() as session:
|
||||
return session.query(User).filter(User.name == name).first()
|
||||
@@ -70,3 +79,184 @@ class BaseDao(Generic[T]):
|
||||
"""
|
||||
with self._db_manager.session() as session:
|
||||
yield session
|
||||
|
||||
def from_request(self, request: QUERY_SPEC) -> T:
|
||||
"""Convert a request schema object to an entity object.
|
||||
|
||||
Args:
|
||||
request (REQ): The request schema object or dict for query.
|
||||
|
||||
Returns:
|
||||
T: The entity object.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def to_request(self, entity: T) -> REQ:
|
||||
"""Convert an entity object to a request schema object.
|
||||
|
||||
Args:
|
||||
entity (T): The entity object.
|
||||
|
||||
Returns:
|
||||
REQ: The request schema object.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def from_response(self, response: RES) -> T:
|
||||
"""Convert a response schema object to an entity object.
|
||||
|
||||
Args:
|
||||
response (RES): The response schema object.
|
||||
|
||||
Returns:
|
||||
T: The entity object.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def to_response(self, entity: T) -> RES:
|
||||
"""Convert an entity object to a response schema object.
|
||||
|
||||
Args:
|
||||
entity (T): The entity object.
|
||||
|
||||
Returns:
|
||||
RES: The response schema object.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def create(self, request: REQ) -> RES:
|
||||
"""Create an entity object.
|
||||
|
||||
Args:
|
||||
request (REQ): The request schema object.
|
||||
|
||||
Returns:
|
||||
RES: The response schema object.
|
||||
"""
|
||||
entry = self.from_request(request)
|
||||
with self.session() as session:
|
||||
session.add(entry)
|
||||
return self.get_one(self.to_request(entry))
|
||||
|
||||
def update(self, query_request: QUERY_SPEC, update_request: REQ) -> RES:
|
||||
"""Update an entity object.
|
||||
|
||||
Args:
|
||||
query_request (REQ): The request schema object or dict for query.
|
||||
update_request (REQ): The request schema object for update.
|
||||
Returns:
|
||||
RES: The response schema object.
|
||||
"""
|
||||
with self.session() as session:
|
||||
query = self._create_query_object(session, query_request)
|
||||
entry = query.first()
|
||||
if entry is None:
|
||||
raise Exception("Invalid request")
|
||||
for key, value in update_request.dict().items():
|
||||
setattr(entry, key, value)
|
||||
session.merge(entry)
|
||||
return self.get_one(self.to_request(entry))
|
||||
|
||||
def delete(self, query_request: QUERY_SPEC) -> None:
|
||||
"""Delete an entity object.
|
||||
|
||||
Args:
|
||||
query_request (REQ): The request schema object or dict for query.
|
||||
"""
|
||||
with self.session() as session:
|
||||
result_list = self._get_entity_list(session, query_request)
|
||||
if len(result_list) != 1:
|
||||
raise ValueError(
|
||||
f"Delete request should return one result, but got {len(result_list)}"
|
||||
)
|
||||
session.delete(result_list[0])
|
||||
|
||||
def get_one(self, query_request: QUERY_SPEC) -> Optional[RES]:
|
||||
"""Get an entity object.
|
||||
|
||||
Args:
|
||||
query_request (REQ): The request schema object or dict for query.
|
||||
|
||||
Returns:
|
||||
Optional[RES]: The response schema object.
|
||||
"""
|
||||
with self.session() as session:
|
||||
query = self._create_query_object(session, query_request)
|
||||
result = query.first()
|
||||
if result is None:
|
||||
return None
|
||||
return self.to_response(result)
|
||||
|
||||
def get_list(self, query_request: QUERY_SPEC) -> List[RES]:
|
||||
"""Get a list of entity objects.
|
||||
|
||||
Args:
|
||||
query_request (REQ): The request schema object or dict for query.
|
||||
Returns:
|
||||
List[RES]: The response schema object.
|
||||
"""
|
||||
with self.session() as session:
|
||||
result_list = self._get_entity_list(session, query_request)
|
||||
return [self.to_response(item) for item in result_list]
|
||||
|
||||
def _get_entity_list(self, session: Session, query_request: QUERY_SPEC) -> List[T]:
|
||||
"""Get a list of entity objects.
|
||||
|
||||
Args:
|
||||
session (Session): The session object.
|
||||
query_request (REQ): The request schema object or dict for query.
|
||||
Returns:
|
||||
List[RES]: The response schema object.
|
||||
"""
|
||||
query = self._create_query_object(session, query_request)
|
||||
result_list = query.all()
|
||||
return result_list
|
||||
|
||||
def get_list_page(
|
||||
self, query_request: QUERY_SPEC, page: int, page_size: int
|
||||
) -> PaginationResult[RES]:
|
||||
"""Get a page of entity objects.
|
||||
|
||||
Args:
|
||||
query_request (REQ): The request schema object or dict for query.
|
||||
page (int): The page number.
|
||||
page_size (int): The page size.
|
||||
|
||||
Returns:
|
||||
PaginationResult: The pagination result.
|
||||
"""
|
||||
with self.session() as session:
|
||||
query = self._create_query_object(session, query_request)
|
||||
total_count = query.count()
|
||||
items = query.offset((page - 1) * page_size).limit(page_size)
|
||||
items = [self.to_response(item) for item in items]
|
||||
total_pages = (total_count + page_size - 1) // page_size
|
||||
|
||||
return PaginationResult(
|
||||
items=items,
|
||||
total_count=total_count,
|
||||
total_pages=total_pages,
|
||||
page=page,
|
||||
page_size=page_size,
|
||||
)
|
||||
|
||||
def _create_query_object(
|
||||
self, session: Session, query_request: QUERY_SPEC
|
||||
) -> BaseQuery:
|
||||
"""Create a query object.
|
||||
|
||||
Args:
|
||||
session (Session): The session object.
|
||||
query_request (QUERY_SPEC): The request schema object or dict for query.
|
||||
Returns:
|
||||
BaseQuery: The query object.
|
||||
"""
|
||||
model_cls = type(self.from_request(query_request))
|
||||
query = session.query(model_cls)
|
||||
query_dict = (
|
||||
query_request if isinstance(query_request, dict) else query_request.dict()
|
||||
)
|
||||
for key, value in query_dict.items():
|
||||
if value is not None:
|
||||
query = query.filter(getattr(model_cls, key) == value)
|
||||
return query
|
||||
|
Reference in New Issue
Block a user