refactor: Refactor storage system (#937)

This commit is contained in:
Fangyin Cheng
2023-12-15 16:35:45 +08:00
committed by GitHub
parent a1e415d68d
commit aed1c3fb2b
55 changed files with 3780 additions and 680 deletions

View File

View File

@@ -0,0 +1,129 @@
from __future__ import annotations
import pytest
from typing import Type
from dbgpt.storage.metadata.db_manager import (
DatabaseManager,
PaginationResult,
create_model,
BaseModel,
)
from sqlalchemy import Column, Integer, String
@pytest.fixture
def db():
    """Return a ``DatabaseManager`` initialised with an in-memory SQLite URL."""
    manager = DatabaseManager()
    manager.init_db("sqlite:///:memory:")
    return manager
@pytest.fixture
def Model(db):
    """Return the result of ``create_model(db)``; tests subclass it to declare tables."""
    model_base = create_model(db)
    return model_base
def test_database_initialization(db: DatabaseManager, Model: Type[BaseModel]):
    """Check that an initialised manager exposes an engine and a session."""
    # ``Model`` is requested as a fixture but is not used directly in the body.
    assert db.engine is not None
    assert db.session is not None

    session_context = db.session()
    with session_context as opened_session:
        assert opened_session is not None
def test_model_creation(db: DatabaseManager, Model: Type[BaseModel]):
    """Check that declaring a model registers its table in ``db.metadata``."""
    # Nothing is registered before any model class has been declared.
    assert db.metadata.tables == {}

    class User(Model):
        __tablename__ = "user"
        id = Column(Integer, primary_key=True)
        name = Column(String(50))

    db.create_all()

    table_names = list(db.metadata.tables.keys())
    assert table_names[0] == "user"
def test_crud_operations(db: DatabaseManager, Model: Type[BaseModel]):
    """Walk one row through create, read, update and delete in separate sessions."""

    class User(Model):
        __tablename__ = "user"
        id = Column(Integer, primary_key=True)
        name = Column(String(50))

    db.create_all()

    # Create: build the row through the model helper, then add and commit it.
    with db.session() as session:
        created = User.create(name="John Doe")
        session.add(created)
        session.commit()

    # Read: the row must be retrievable by name from a fresh session.
    with db.session() as session:
        by_original_name = session.query(User).filter_by(name="John Doe")
        fetched = by_original_name.first()
        assert fetched is not None

    # Update: rename the row through the model's ``update`` helper.
    with db.session() as session:
        by_original_name = session.query(User).filter_by(name="John Doe")
        to_rename = by_original_name.first()
        to_rename.update(name="Jane Doe")

    # Delete: look the row up under its new name and remove it.
    with db.session() as session:
        by_new_name = session.query(User).filter_by(name="Jane Doe")
        to_remove = by_new_name.first()
        to_remove.delete()
def test_crud_mixins(db: DatabaseManager, Model: Type[BaseModel]):
    """Exercise the ``create``/``get``/``all``/``update``/``delete`` model helpers."""

    class User(Model):
        __tablename__ = "user"
        id = Column(Integer, primary_key=True)
        name = Column(String(50))

    db.create_all()

    # Create: the new row is retrievable by id and is the only row present.
    created = User.create(name="John Doe")
    assert User.get(created.id) is not None
    assert len(User.all()) == 1

    # Update: the new name is visible when the row is fetched again.
    created.update(name="Bob Doe")
    renamed = User.get(created.id)
    assert renamed.name == "Bob Doe"

    # Delete: remove a freshly fetched instance, then look it up by its id.
    reloaded = User.get(created.id)
    reloaded.delete()
    assert User.get(reloaded.id) is None
def test_pagination_query(db: DatabaseManager, Model: Type[BaseModel]):
    """Check the first page and the page count when paginating 30 rows by 10."""

    class User(Model):
        __tablename__ = "user"
        id = Column(Integer, primary_key=True)
        name = Column(String(50))

    db.create_all()

    # Add data: 30 rows, so a page size of 10 gives exactly three pages.
    with db.session() as session:
        for index in range(30):
            row = User(name=f"User {index}")
            session.add(row)
        session.commit()

    first_page = User.query.paginate_query(page=1, per_page=10)
    assert len(first_page.items) == 10
    assert first_page.total_pages == 3
def test_invalid_pagination(db: DatabaseManager, Model: Type[BaseModel]):
    """Check that ``paginate_query`` raises ``ValueError`` for invalid arguments."""

    class User(Model):
        __tablename__ = "user"
        id = Column(Integer, primary_key=True)
        name = Column(String(50))

    db.create_all()

    # First case: page number zero; second case: negative page size.
    invalid_arguments = (
        {"page": 0, "per_page": 10},
        {"page": 1, "per_page": -1},
    )
    for arguments in invalid_arguments:
        with pytest.raises(ValueError):
            User.query.paginate_query(**arguments)

View File

@@ -0,0 +1,173 @@
from typing import Dict, Type
from sqlalchemy.orm import declarative_base, Session
from sqlalchemy import Column, Integer, String
import pytest
from dbgpt.core.interface.storage import (
StorageItem,
ResourceIdentifier,
StorageItemAdapter,
QuerySpec,
)
from dbgpt.storage.metadata.db_storage import SQLAlchemyStorage
from dbgpt.core.interface.tests.test_storage import MockResourceIdentifier
from dbgpt.util.serialization.json_serialization import JsonSerializer
# Declarative base for the mock model in this module. The ``sqlalchemy_storage``
# fixture passes it to the storage as ``base=`` and creates its tables through
# ``Base.metadata.create_all``.
Base = declarative_base()
class MockModel(Base):
    """SQLAlchemy model backing the mock storage items used by these tests."""

    __tablename__ = "mock_data"

    # Integer primary key; the adapter fills it from the item's identifier.
    id = Column(Integer, primary_key=True)
    # String payload carried by the mock storage item.
    data = Column(String)
class MockStorageItem(StorageItem):
    """Minimal storage item holding an identifier and a ``data`` payload."""

    def __init__(self, identifier: ResourceIdentifier, data: str):
        self.data = data
        self._identifier = identifier

    @property
    def identifier(self) -> ResourceIdentifier:
        """Return the identifier supplied at construction."""
        return self._identifier

    def merge(self, other: "StorageItem") -> None:
        """Replace this item's data with the data of ``other``.

        Raises:
            ValueError: If ``other`` is not a ``MockStorageItem``.
        """
        if isinstance(other, MockStorageItem):
            self.data = other.data
            return
        raise ValueError("other must be a MockStorageItem")

    def to_dict(self) -> Dict:
        """Return the identifier object and the data in a dictionary."""
        return dict(identifier=self._identifier, data=self.data)

    def serialize(self) -> bytes:
        """Return the string form of ``data`` encoded to bytes."""
        text = str(self.data)
        return text.encode()
class MockStorageItemAdapter(StorageItemAdapter[MockStorageItem, MockModel]):
    """Convert between ``MockStorageItem`` objects and ``MockModel`` rows."""

    def to_storage_format(self, item: MockStorageItem) -> MockModel:
        """Build a row whose ``id`` is the item's string identifier parsed as int."""
        row_id = int(item.identifier.str_identifier)
        return MockModel(id=row_id, data=item.data)

    def from_storage_format(self, model: MockModel) -> MockStorageItem:
        """Rebuild an item from a row, using the row id as the string identifier."""
        resource_id = MockResourceIdentifier(str(model.id))
        return MockStorageItem(resource_id, model.data)

    def get_query_for_identifier(
        self,
        storage_format: Type[MockModel],
        resource_id: ResourceIdentifier,
        **kwargs,
    ):
        """Return a query on ``storage_format`` filtered by the identifier's id.

        Raises:
            ValueError: If no ``session`` keyword argument is given (or it is None).
        """
        session: Session = kwargs.get("session")
        if session is None:
            raise ValueError("session is required for this adapter")
        base_query = session.query(storage_format)
        wanted_id = int(resource_id.str_identifier)
        return base_query.filter(storage_format.id == wanted_id)
@pytest.fixture
def serializer():
    """Return a fresh ``JsonSerializer`` for the storage under test."""
    json_serializer = JsonSerializer()
    return json_serializer
@pytest.fixture
def db_url():
    """Return the URL of an in-memory SQLite database used for testing."""
    in_memory_url = "sqlite:///:memory:"
    return in_memory_url
@pytest.fixture
def sqlalchemy_storage(db_url, serializer):
    """Return a ``SQLAlchemyStorage`` for ``MockModel`` with its tables created."""
    storage = SQLAlchemyStorage(
        db_url,
        MockModel,
        MockStorageItemAdapter(),
        serializer,
        base=Base,
    )
    # Create the tables declared on ``Base`` using the storage's own engine.
    engine = storage.db_manager.engine
    Base.metadata.create_all(engine)
    return storage
def test_save_and_load(sqlalchemy_storage):
    """Check that a saved item loads back with the same data."""
    stored = MockStorageItem(MockResourceIdentifier("1"), "test_data")
    sqlalchemy_storage.save(stored)

    # Look the item up with a separately constructed identifier of the same value.
    lookup_id = MockResourceIdentifier("1")
    restored = sqlalchemy_storage.load(lookup_id, MockStorageItem)
    assert restored.data == "test_data"
def test_delete(sqlalchemy_storage):
    """Check that ``delete`` removes a stored item and tolerates a missing one.

    The first step deletes an identifier that was never saved; on its own that
    passes even when ``delete`` does nothing. The second step stores an item,
    confirms it is present, and only then deletes it, so the final assertion
    actually depends on ``delete`` removing the row.
    """
    resource_id = MockResourceIdentifier("1")

    # Deleting an identifier that was never saved: nothing is loadable afterwards.
    sqlalchemy_storage.delete(resource_id)
    assert sqlalchemy_storage.load(resource_id, MockStorageItem) is None

    # Store an item and confirm it is present so the next delete is meaningful.
    sqlalchemy_storage.save(MockStorageItem(resource_id, "test_data"))
    assert sqlalchemy_storage.load(resource_id, MockStorageItem) is not None

    # Make sure the item is deleted
    sqlalchemy_storage.delete(resource_id)
    assert sqlalchemy_storage.load(resource_id, MockStorageItem) is None
def test_query_with_various_conditions(sqlalchemy_storage):
    """Query five stored items by one, by a non-matching, and by two conditions."""

    def run_query(conditions):
        # Wrap the conditions in a QuerySpec and run them against the storage.
        spec = QuerySpec(conditions=conditions)
        return sqlalchemy_storage.query(spec, MockStorageItem)

    # Add multiple items for testing
    for index in range(5):
        identifier = MockResourceIdentifier(str(index))
        sqlalchemy_storage.save(MockStorageItem(identifier, f"test_data_{index}"))

    # Single condition: exactly one item carries this data value.
    single_match = run_query({"data": "test_data_2"})
    assert len(single_match) == 1
    assert single_match[0].data == "test_data_2"

    # Condition that matches no stored item.
    no_match = run_query({"data": "nonexistent"})
    assert len(no_match) == 0

    # Multiple conditions naming the same row.
    combined_match = run_query({"data": "test_data_2", "id": "2"})
    assert len(combined_match) == 1
def test_query_nonexistent_item(sqlalchemy_storage):
    """Check that querying a storage with no saved items returns no results."""
    spec = QuerySpec(conditions={"data": "nonexistent"})
    found = sqlalchemy_storage.query(spec, MockStorageItem)
    assert len(found) == 0
def test_count_items(sqlalchemy_storage):
    """Count five stored items with and without a filter condition."""
    for index in range(5):
        identifier = MockResourceIdentifier(str(index))
        sqlalchemy_storage.save(MockStorageItem(identifier, f"test_data_{index}"))

    # Empty conditions: every stored item is counted.
    unfiltered = QuerySpec(conditions={})
    assert sqlalchemy_storage.count(unfiltered, MockStorageItem) == 5

    # One condition: only the single item with this data value is counted.
    filtered = QuerySpec(conditions={"data": "test_data_2"})
    assert sqlalchemy_storage.count(filtered, MockStorageItem) == 1
def test_paginate_query(sqlalchemy_storage):
    """Paginate ten stored items three per page and inspect the second page."""
    for index in range(10):
        identifier = MockResourceIdentifier(str(index))
        sqlalchemy_storage.save(MockStorageItem(identifier, f"test_data_{index}"))

    per_page, requested_page = 3, 2
    everything = QuerySpec(conditions={})
    result = sqlalchemy_storage.paginate_query(
        requested_page, per_page, MockStorageItem, everything
    )

    # Ten items at three per page: a full second page and four pages in total.
    assert len(result.items) == per_page
    assert result.page == requested_page
    assert result.total_pages == 4
    assert result.total_count == 10