mirror of
https://github.com/csunny/DB-GPT.git
synced 2025-09-17 23:18:20 +00:00
refactor: Refactor storage and new serve template (#947)
This commit is contained in:
@@ -1,18 +1,27 @@
|
||||
from contextlib import contextmanager
|
||||
from typing import TypeVar, Generic, Any, Optional
|
||||
from typing import TypeVar, Generic, Any, Optional, Dict, Union, List
|
||||
from sqlalchemy.orm.session import Session
|
||||
from dbgpt.util.pagination_utils import PaginationResult
|
||||
|
||||
# The entity type
|
||||
T = TypeVar("T")
|
||||
# The request schema type
|
||||
REQ = TypeVar("REQ")
|
||||
# The response schema type
|
||||
RES = TypeVar("RES")
|
||||
|
||||
from .db_manager import db, DatabaseManager
|
||||
from .db_manager import db, DatabaseManager, BaseQuery
|
||||
|
||||
|
||||
class BaseDao(Generic[T]):
|
||||
QUERY_SPEC = Union[REQ, Dict[str, Any]]
|
||||
|
||||
|
||||
class BaseDao(Generic[T, REQ, RES]):
|
||||
"""The base class for all DAOs.
|
||||
|
||||
Examples:
|
||||
.. code-block:: python
|
||||
class UserDao(BaseDao[User]):
|
||||
class UserDao(BaseDao):
|
||||
def get_user_by_name(self, name: str) -> User:
|
||||
with self.session() as session:
|
||||
return session.query(User).filter(User.name == name).first()
|
||||
@@ -70,3 +79,184 @@ class BaseDao(Generic[T]):
|
||||
"""
|
||||
with self._db_manager.session() as session:
|
||||
yield session
|
||||
|
||||
def from_request(self, request: QUERY_SPEC) -> T:
|
||||
"""Convert a request schema object to an entity object.
|
||||
|
||||
Args:
|
||||
request (REQ): The request schema object or dict for query.
|
||||
|
||||
Returns:
|
||||
T: The entity object.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def to_request(self, entity: T) -> REQ:
|
||||
"""Convert an entity object to a request schema object.
|
||||
|
||||
Args:
|
||||
entity (T): The entity object.
|
||||
|
||||
Returns:
|
||||
REQ: The request schema object.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def from_response(self, response: RES) -> T:
|
||||
"""Convert a response schema object to an entity object.
|
||||
|
||||
Args:
|
||||
response (RES): The response schema object.
|
||||
|
||||
Returns:
|
||||
T: The entity object.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def to_response(self, entity: T) -> RES:
|
||||
"""Convert an entity object to a response schema object.
|
||||
|
||||
Args:
|
||||
entity (T): The entity object.
|
||||
|
||||
Returns:
|
||||
RES: The response schema object.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def create(self, request: REQ) -> RES:
|
||||
"""Create an entity object.
|
||||
|
||||
Args:
|
||||
request (REQ): The request schema object.
|
||||
|
||||
Returns:
|
||||
RES: The response schema object.
|
||||
"""
|
||||
entry = self.from_request(request)
|
||||
with self.session() as session:
|
||||
session.add(entry)
|
||||
return self.get_one(self.to_request(entry))
|
||||
|
||||
def update(self, query_request: QUERY_SPEC, update_request: REQ) -> RES:
|
||||
"""Update an entity object.
|
||||
|
||||
Args:
|
||||
query_request (REQ): The request schema object or dict for query.
|
||||
update_request (REQ): The request schema object for update.
|
||||
Returns:
|
||||
RES: The response schema object.
|
||||
"""
|
||||
with self.session() as session:
|
||||
query = self._create_query_object(session, query_request)
|
||||
entry = query.first()
|
||||
if entry is None:
|
||||
raise Exception("Invalid request")
|
||||
for key, value in update_request.dict().items():
|
||||
setattr(entry, key, value)
|
||||
session.merge(entry)
|
||||
return self.get_one(self.to_request(entry))
|
||||
|
||||
def delete(self, query_request: QUERY_SPEC) -> None:
|
||||
"""Delete an entity object.
|
||||
|
||||
Args:
|
||||
query_request (REQ): The request schema object or dict for query.
|
||||
"""
|
||||
with self.session() as session:
|
||||
result_list = self._get_entity_list(session, query_request)
|
||||
if len(result_list) != 1:
|
||||
raise ValueError(
|
||||
f"Delete request should return one result, but got {len(result_list)}"
|
||||
)
|
||||
session.delete(result_list[0])
|
||||
|
||||
def get_one(self, query_request: QUERY_SPEC) -> Optional[RES]:
|
||||
"""Get an entity object.
|
||||
|
||||
Args:
|
||||
query_request (REQ): The request schema object or dict for query.
|
||||
|
||||
Returns:
|
||||
Optional[RES]: The response schema object.
|
||||
"""
|
||||
with self.session() as session:
|
||||
query = self._create_query_object(session, query_request)
|
||||
result = query.first()
|
||||
if result is None:
|
||||
return None
|
||||
return self.to_response(result)
|
||||
|
||||
def get_list(self, query_request: QUERY_SPEC) -> List[RES]:
|
||||
"""Get a list of entity objects.
|
||||
|
||||
Args:
|
||||
query_request (REQ): The request schema object or dict for query.
|
||||
Returns:
|
||||
List[RES]: The response schema object.
|
||||
"""
|
||||
with self.session() as session:
|
||||
result_list = self._get_entity_list(session, query_request)
|
||||
return [self.to_response(item) for item in result_list]
|
||||
|
||||
def _get_entity_list(self, session: Session, query_request: QUERY_SPEC) -> List[T]:
|
||||
"""Get a list of entity objects.
|
||||
|
||||
Args:
|
||||
session (Session): The session object.
|
||||
query_request (REQ): The request schema object or dict for query.
|
||||
Returns:
|
||||
List[RES]: The response schema object.
|
||||
"""
|
||||
query = self._create_query_object(session, query_request)
|
||||
result_list = query.all()
|
||||
return result_list
|
||||
|
||||
def get_list_page(
|
||||
self, query_request: QUERY_SPEC, page: int, page_size: int
|
||||
) -> PaginationResult[RES]:
|
||||
"""Get a page of entity objects.
|
||||
|
||||
Args:
|
||||
query_request (REQ): The request schema object or dict for query.
|
||||
page (int): The page number.
|
||||
page_size (int): The page size.
|
||||
|
||||
Returns:
|
||||
PaginationResult: The pagination result.
|
||||
"""
|
||||
with self.session() as session:
|
||||
query = self._create_query_object(session, query_request)
|
||||
total_count = query.count()
|
||||
items = query.offset((page - 1) * page_size).limit(page_size)
|
||||
items = [self.to_response(item) for item in items]
|
||||
total_pages = (total_count + page_size - 1) // page_size
|
||||
|
||||
return PaginationResult(
|
||||
items=items,
|
||||
total_count=total_count,
|
||||
total_pages=total_pages,
|
||||
page=page,
|
||||
page_size=page_size,
|
||||
)
|
||||
|
||||
def _create_query_object(
|
||||
self, session: Session, query_request: QUERY_SPEC
|
||||
) -> BaseQuery:
|
||||
"""Create a query object.
|
||||
|
||||
Args:
|
||||
session (Session): The session object.
|
||||
query_request (QUERY_SPEC): The request schema object or dict for query.
|
||||
Returns:
|
||||
BaseQuery: The query object.
|
||||
"""
|
||||
model_cls = type(self.from_request(query_request))
|
||||
query = session.query(model_cls)
|
||||
query_dict = (
|
||||
query_request if isinstance(query_request, dict) else query_request.dict()
|
||||
)
|
||||
for key, value in query_dict.items():
|
||||
if value is not None:
|
||||
query = query.filter(getattr(model_cls, key) == value)
|
||||
return query
|
||||
|
@@ -103,7 +103,8 @@ class SQLAlchemyStorage(StorageInterface[T, BaseModel]):
|
||||
with self.session() as session:
|
||||
query = session.query(self._model_class)
|
||||
for key, value in spec.conditions.items():
|
||||
query = query.filter(getattr(self._model_class, key) == value)
|
||||
if value is not None:
|
||||
query = query.filter(getattr(self._model_class, key) == value)
|
||||
if spec.limit is not None:
|
||||
query = query.limit(spec.limit)
|
||||
if spec.offset is not None:
|
||||
@@ -124,5 +125,6 @@ class SQLAlchemyStorage(StorageInterface[T, BaseModel]):
|
||||
with self.session() as session:
|
||||
query = session.query(self._model_class)
|
||||
for key, value in spec.conditions.items():
|
||||
query = query.filter(getattr(self._model_class, key) == value)
|
||||
if value is not None:
|
||||
query = query.filter(getattr(self._model_class, key) == value)
|
||||
return query.count()
|
||||
|
152
dbgpt/storage/metadata/tests/test_base_dao.py
Normal file
152
dbgpt/storage/metadata/tests/test_base_dao.py
Normal file
@@ -0,0 +1,152 @@
|
||||
from typing import Type, Optional, Union, Dict, Any
|
||||
import pytest
|
||||
from sqlalchemy import Column, Integer, String
|
||||
from dbgpt._private.pydantic import BaseModel as PydanticBaseModel, Field
|
||||
from dbgpt.storage.metadata.db_manager import (
|
||||
DatabaseManager,
|
||||
PaginationResult,
|
||||
create_model,
|
||||
BaseModel,
|
||||
)
|
||||
|
||||
from .._base_dao import BaseDao
|
||||
|
||||
|
||||
class UserRequest(PydanticBaseModel):
|
||||
name: str = Field(..., description="User name")
|
||||
age: Optional[int] = Field(default=-1, description="User age")
|
||||
password: Optional[str] = Field(default="", description="User password")
|
||||
|
||||
|
||||
class UserResponse(PydanticBaseModel):
|
||||
id: int = Field(..., description="User id")
|
||||
name: str = Field(..., description="User name")
|
||||
age: Optional[int] = Field(default=-1, description="User age")
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def db():
|
||||
db = DatabaseManager()
|
||||
db.init_db("sqlite:///:memory:")
|
||||
return db
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def Model(db):
|
||||
return create_model(db)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def User(Model):
|
||||
class User(Model):
|
||||
__tablename__ = "user"
|
||||
id = Column(Integer, primary_key=True)
|
||||
name = Column(String(50), unique=True)
|
||||
age = Column(Integer)
|
||||
password = Column(String(50))
|
||||
|
||||
return User
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def user_req():
|
||||
return UserRequest(name="Edward Snowden", age=30, password="123456")
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def user_dao(db, User):
|
||||
class UserDao(BaseDao[User, UserRequest, UserResponse]):
|
||||
def from_request(self, request: Union[UserRequest, Dict[str, Any]]) -> User:
|
||||
if isinstance(request, UserRequest):
|
||||
return User(**request.dict())
|
||||
else:
|
||||
return User(**request)
|
||||
|
||||
def to_request(self, entity: User) -> UserRequest:
|
||||
return UserRequest(
|
||||
name=entity.name, age=entity.age, password=entity.password
|
||||
)
|
||||
|
||||
def from_response(self, response: UserResponse) -> User:
|
||||
return User(**response.dict())
|
||||
|
||||
def to_response(self, entity: User):
|
||||
return UserResponse(id=entity.id, name=entity.name, age=entity.age)
|
||||
|
||||
db.create_all()
|
||||
return UserDao(db)
|
||||
|
||||
|
||||
def test_create_user(db: DatabaseManager, User: Type[BaseModel], user_dao, user_req):
|
||||
user_dao.create(user_req)
|
||||
with db.session() as session:
|
||||
user = session.query(User).first()
|
||||
assert user.name == user_req.name
|
||||
assert user.age == user_req.age
|
||||
assert user.password == user_req.password
|
||||
|
||||
|
||||
def test_update_user(db: DatabaseManager, User: Type[BaseModel], user_dao, user_req):
|
||||
# Create a user
|
||||
created_user_response = user_dao.create(user_req)
|
||||
|
||||
# Update the user
|
||||
updated_req = UserRequest(name=user_req.name, age=35, password="newpassword")
|
||||
updated_user = user_dao.update(
|
||||
query_request={"name": user_req.name}, update_request=updated_req
|
||||
)
|
||||
assert updated_user.id == created_user_response.id
|
||||
assert updated_user.age == 35
|
||||
|
||||
# Verify that the user is updated in the database
|
||||
with db.session() as session:
|
||||
user = session.query(User).get(created_user_response.id)
|
||||
assert user.age == 35
|
||||
|
||||
|
||||
def test_get_user(db: DatabaseManager, User: Type[BaseModel], user_dao, user_req):
|
||||
# Create a user
|
||||
created_user_response = user_dao.create(user_req)
|
||||
|
||||
# Query the user
|
||||
fetched_user = user_dao.get_one({"name": user_req.name})
|
||||
assert fetched_user.id == created_user_response.id
|
||||
assert fetched_user.name == user_req.name
|
||||
assert fetched_user.age == user_req.age
|
||||
|
||||
|
||||
def test_get_list_user(db: DatabaseManager, User: Type[BaseModel], user_dao):
|
||||
for i in range(20):
|
||||
user_dao.create(
|
||||
UserRequest(
|
||||
name=f"User {i}", age=i, password="123456" if i % 2 == 0 else "abcdefg"
|
||||
)
|
||||
)
|
||||
# Query the user
|
||||
fetched_user = user_dao.get_list({"password": "123456"})
|
||||
assert len(fetched_user) == 10
|
||||
|
||||
|
||||
def test_get_list_page_user(db: DatabaseManager, User: Type[BaseModel], user_dao):
|
||||
for i in range(20):
|
||||
user_dao.create(
|
||||
UserRequest(
|
||||
name=f"User {i}", age=i, password="123456" if i % 2 == 0 else "abcdefg"
|
||||
)
|
||||
)
|
||||
page_result: PaginationResult = user_dao.get_list_page(
|
||||
{"password": "123456"}, page=1, page_size=3
|
||||
)
|
||||
assert page_result.total_count == 10
|
||||
assert page_result.total_pages == 4
|
||||
assert len(page_result.items) == 3
|
||||
assert page_result.items[0].name == "User 0"
|
||||
|
||||
# Test query next page
|
||||
page_result: PaginationResult = user_dao.get_list_page(
|
||||
{"password": "123456"}, page=2, page_size=3
|
||||
)
|
||||
assert page_result.total_count == 10
|
||||
assert page_result.total_pages == 4
|
||||
assert len(page_result.items) == 3
|
||||
assert page_result.items[0].name == "User 6"
|
Reference in New Issue
Block a user