mirror of
https://github.com/csunny/DB-GPT.git
synced 2025-09-18 07:30:40 +00:00
refactor: Refactor storage system (#937)
This commit is contained in:
0
dbgpt/storage/metadata/tests/__init__.py
Normal file
0
dbgpt/storage/metadata/tests/__init__.py
Normal file
129
dbgpt/storage/metadata/tests/test_db_manager.py
Normal file
129
dbgpt/storage/metadata/tests/test_db_manager.py
Normal file
@@ -0,0 +1,129 @@
|
||||
from __future__ import annotations
|
||||
import pytest
|
||||
from typing import Type
|
||||
from dbgpt.storage.metadata.db_manager import (
|
||||
DatabaseManager,
|
||||
PaginationResult,
|
||||
create_model,
|
||||
BaseModel,
|
||||
)
|
||||
from sqlalchemy import Column, Integer, String
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def db():
|
||||
db = DatabaseManager()
|
||||
db.init_db("sqlite:///:memory:")
|
||||
return db
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def Model(db):
|
||||
return create_model(db)
|
||||
|
||||
|
||||
def test_database_initialization(db: DatabaseManager, Model: Type[BaseModel]):
|
||||
assert db.engine is not None
|
||||
assert db.session is not None
|
||||
|
||||
with db.session() as session:
|
||||
assert session is not None
|
||||
|
||||
|
||||
def test_model_creation(db: DatabaseManager, Model: Type[BaseModel]):
|
||||
assert db.metadata.tables == {}
|
||||
|
||||
class User(Model):
|
||||
__tablename__ = "user"
|
||||
id = Column(Integer, primary_key=True)
|
||||
name = Column(String(50))
|
||||
|
||||
db.create_all()
|
||||
assert list(db.metadata.tables.keys())[0] == "user"
|
||||
|
||||
|
||||
def test_crud_operations(db: DatabaseManager, Model: Type[BaseModel]):
|
||||
class User(Model):
|
||||
__tablename__ = "user"
|
||||
id = Column(Integer, primary_key=True)
|
||||
name = Column(String(50))
|
||||
|
||||
db.create_all()
|
||||
|
||||
# Create
|
||||
with db.session() as session:
|
||||
user = User.create(name="John Doe")
|
||||
session.add(user)
|
||||
session.commit()
|
||||
|
||||
# Read
|
||||
with db.session() as session:
|
||||
user = session.query(User).filter_by(name="John Doe").first()
|
||||
assert user is not None
|
||||
|
||||
# Update
|
||||
with db.session() as session:
|
||||
user = session.query(User).filter_by(name="John Doe").first()
|
||||
user.update(name="Jane Doe")
|
||||
|
||||
# Delete
|
||||
with db.session() as session:
|
||||
user = session.query(User).filter_by(name="Jane Doe").first()
|
||||
user.delete()
|
||||
|
||||
|
||||
def test_crud_mixins(db: DatabaseManager, Model: Type[BaseModel]):
|
||||
class User(Model):
|
||||
__tablename__ = "user"
|
||||
id = Column(Integer, primary_key=True)
|
||||
name = Column(String(50))
|
||||
|
||||
db.create_all()
|
||||
|
||||
# Create
|
||||
user = User.create(name="John Doe")
|
||||
assert User.get(user.id) is not None
|
||||
users = User.all()
|
||||
assert len(users) == 1
|
||||
|
||||
# Update
|
||||
user.update(name="Bob Doe")
|
||||
assert User.get(user.id).name == "Bob Doe"
|
||||
|
||||
user = User.get(user.id)
|
||||
user.delete()
|
||||
assert User.get(user.id) is None
|
||||
|
||||
|
||||
def test_pagination_query(db: DatabaseManager, Model: Type[BaseModel]):
|
||||
class User(Model):
|
||||
__tablename__ = "user"
|
||||
id = Column(Integer, primary_key=True)
|
||||
name = Column(String(50))
|
||||
|
||||
db.create_all()
|
||||
|
||||
# 添加数据
|
||||
with db.session() as session:
|
||||
for i in range(30):
|
||||
user = User(name=f"User {i}")
|
||||
session.add(user)
|
||||
session.commit()
|
||||
|
||||
users_page_1 = User.query.paginate_query(page=1, per_page=10)
|
||||
assert len(users_page_1.items) == 10
|
||||
assert users_page_1.total_pages == 3
|
||||
|
||||
|
||||
def test_invalid_pagination(db: DatabaseManager, Model: Type[BaseModel]):
|
||||
class User(Model):
|
||||
__tablename__ = "user"
|
||||
id = Column(Integer, primary_key=True)
|
||||
name = Column(String(50))
|
||||
|
||||
db.create_all()
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
User.query.paginate_query(page=0, per_page=10)
|
||||
with pytest.raises(ValueError):
|
||||
User.query.paginate_query(page=1, per_page=-1)
|
173
dbgpt/storage/metadata/tests/test_sqlalchemy_storage.py
Normal file
173
dbgpt/storage/metadata/tests/test_sqlalchemy_storage.py
Normal file
@@ -0,0 +1,173 @@
|
||||
from typing import Dict, Type
|
||||
from sqlalchemy.orm import declarative_base, Session
|
||||
from sqlalchemy import Column, Integer, String
|
||||
|
||||
import pytest
|
||||
|
||||
from dbgpt.core.interface.storage import (
|
||||
StorageItem,
|
||||
ResourceIdentifier,
|
||||
StorageItemAdapter,
|
||||
QuerySpec,
|
||||
)
|
||||
from dbgpt.storage.metadata.db_storage import SQLAlchemyStorage
|
||||
|
||||
from dbgpt.core.interface.tests.test_storage import MockResourceIdentifier
|
||||
from dbgpt.util.serialization.json_serialization import JsonSerializer
|
||||
|
||||
|
||||
Base = declarative_base()
|
||||
|
||||
|
||||
class MockModel(Base):
|
||||
"""The SQLAlchemy model for the mock data."""
|
||||
|
||||
__tablename__ = "mock_data"
|
||||
id = Column(Integer, primary_key=True)
|
||||
data = Column(String)
|
||||
|
||||
|
||||
class MockStorageItem(StorageItem):
|
||||
"""The mock storage item."""
|
||||
|
||||
def merge(self, other: "StorageItem") -> None:
|
||||
if not isinstance(other, MockStorageItem):
|
||||
raise ValueError("other must be a MockStorageItem")
|
||||
self.data = other.data
|
||||
|
||||
def __init__(self, identifier: ResourceIdentifier, data: str):
|
||||
self._identifier = identifier
|
||||
self.data = data
|
||||
|
||||
@property
|
||||
def identifier(self) -> ResourceIdentifier:
|
||||
return self._identifier
|
||||
|
||||
def to_dict(self) -> Dict:
|
||||
return {"identifier": self._identifier, "data": self.data}
|
||||
|
||||
def serialize(self) -> bytes:
|
||||
return str(self.data).encode()
|
||||
|
||||
|
||||
class MockStorageItemAdapter(StorageItemAdapter[MockStorageItem, MockModel]):
|
||||
"""The adapter for the mock storage item."""
|
||||
|
||||
def to_storage_format(self, item: MockStorageItem) -> MockModel:
|
||||
return MockModel(id=int(item.identifier.str_identifier), data=item.data)
|
||||
|
||||
def from_storage_format(self, model: MockModel) -> MockStorageItem:
|
||||
return MockStorageItem(MockResourceIdentifier(str(model.id)), model.data)
|
||||
|
||||
def get_query_for_identifier(
|
||||
self,
|
||||
storage_format: Type[MockModel],
|
||||
resource_id: ResourceIdentifier,
|
||||
**kwargs,
|
||||
):
|
||||
session: Session = kwargs.get("session")
|
||||
if session is None:
|
||||
raise ValueError("session is required for this adapter")
|
||||
return session.query(storage_format).filter(
|
||||
storage_format.id == int(resource_id.str_identifier)
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def serializer():
|
||||
return JsonSerializer()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def db_url():
|
||||
"""Use in-memory SQLite database for testing"""
|
||||
return "sqlite:///:memory:"
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sqlalchemy_storage(db_url, serializer):
|
||||
adapter = MockStorageItemAdapter()
|
||||
storage = SQLAlchemyStorage(db_url, MockModel, adapter, serializer, base=Base)
|
||||
Base.metadata.create_all(storage.db_manager.engine)
|
||||
return storage
|
||||
|
||||
|
||||
def test_save_and_load(sqlalchemy_storage):
|
||||
item = MockStorageItem(MockResourceIdentifier("1"), "test_data")
|
||||
|
||||
sqlalchemy_storage.save(item)
|
||||
|
||||
loaded_item = sqlalchemy_storage.load(MockResourceIdentifier("1"), MockStorageItem)
|
||||
assert loaded_item.data == "test_data"
|
||||
|
||||
|
||||
def test_delete(sqlalchemy_storage):
|
||||
resource_id = MockResourceIdentifier("1")
|
||||
|
||||
sqlalchemy_storage.delete(resource_id)
|
||||
# Make sure the item is deleted
|
||||
assert sqlalchemy_storage.load(resource_id, MockStorageItem) is None
|
||||
|
||||
|
||||
def test_query_with_various_conditions(sqlalchemy_storage):
|
||||
# Add multiple items for testing
|
||||
for i in range(5):
|
||||
item = MockStorageItem(MockResourceIdentifier(str(i)), f"test_data_{i}")
|
||||
sqlalchemy_storage.save(item)
|
||||
|
||||
# Test query with single condition
|
||||
query_spec = QuerySpec(conditions={"data": "test_data_2"})
|
||||
results = sqlalchemy_storage.query(query_spec, MockStorageItem)
|
||||
assert len(results) == 1
|
||||
assert results[0].data == "test_data_2"
|
||||
|
||||
# Test not existing condition
|
||||
query_spec = QuerySpec(conditions={"data": "nonexistent"})
|
||||
results = sqlalchemy_storage.query(query_spec, MockStorageItem)
|
||||
assert len(results) == 0
|
||||
|
||||
# Test query with multiple conditions
|
||||
query_spec = QuerySpec(conditions={"data": "test_data_2", "id": "2"})
|
||||
results = sqlalchemy_storage.query(query_spec, MockStorageItem)
|
||||
assert len(results) == 1
|
||||
|
||||
|
||||
def test_query_nonexistent_item(sqlalchemy_storage):
|
||||
query_spec = QuerySpec(conditions={"data": "nonexistent"})
|
||||
results = sqlalchemy_storage.query(query_spec, MockStorageItem)
|
||||
assert len(results) == 0
|
||||
|
||||
|
||||
def test_count_items(sqlalchemy_storage):
|
||||
for i in range(5):
|
||||
item = MockStorageItem(MockResourceIdentifier(str(i)), f"test_data_{i}")
|
||||
sqlalchemy_storage.save(item)
|
||||
|
||||
# Test count without conditions
|
||||
query_spec = QuerySpec(conditions={})
|
||||
total_count = sqlalchemy_storage.count(query_spec, MockStorageItem)
|
||||
assert total_count == 5
|
||||
|
||||
# Test count with conditions
|
||||
query_spec = QuerySpec(conditions={"data": "test_data_2"})
|
||||
total_count = sqlalchemy_storage.count(query_spec, MockStorageItem)
|
||||
assert total_count == 1
|
||||
|
||||
|
||||
def test_paginate_query(sqlalchemy_storage):
|
||||
for i in range(10):
|
||||
item = MockStorageItem(MockResourceIdentifier(str(i)), f"test_data_{i}")
|
||||
sqlalchemy_storage.save(item)
|
||||
|
||||
page_size = 3
|
||||
page_number = 2
|
||||
|
||||
query_spec = QuerySpec(conditions={})
|
||||
page_result = sqlalchemy_storage.paginate_query(
|
||||
page_number, page_size, MockStorageItem, query_spec
|
||||
)
|
||||
|
||||
assert len(page_result.items) == page_size
|
||||
assert page_result.page == page_number
|
||||
assert page_result.total_pages == 4
|
||||
assert page_result.total_count == 10
|
Reference in New Issue
Block a user