refactor: Refactor storage system (#937)

This commit is contained in:
Fangyin Cheng
2023-12-15 16:35:45 +08:00
committed by GitHub
parent a1e415d68d
commit aed1c3fb2b
55 changed files with 3780 additions and 680 deletions

View File

@@ -1 +1,17 @@
from dbgpt.storage.metadata.db_manager import (
db,
Model,
DatabaseManager,
create_model,
BaseModel,
)
from dbgpt.storage.metadata._base_dao import BaseDao
__ALL__ = [
"db",
"Model",
"DatabaseManager",
"create_model",
"BaseModel",
"BaseDao",
]

View File

@@ -1,25 +1,72 @@
from typing import TypeVar, Generic, Any
from sqlalchemy.orm import sessionmaker
from contextlib import contextmanager
from typing import TypeVar, Generic, Any, Optional
from sqlalchemy.orm.session import Session
T = TypeVar("T")
from .db_manager import db, DatabaseManager
class BaseDao(Generic[T]):
"""The base class for all DAOs.
Examples:
.. code-block:: python
class UserDao(BaseDao[User]):
def get_user_by_name(self, name: str) -> User:
with self.session() as session:
return session.query(User).filter(User.name == name).first()
def get_user_by_id(self, id: int) -> User:
with self.session() as session:
return User.get(id)
def create_user(self, name: str) -> User:
return User.create(**{"name": name})
Args:
db_manager (DatabaseManager, optional): The database manager. Defaults to None.
If None, the default database manager(db) will be used.
"""
def __init__(
self,
orm_base=None,
database: str = None,
db_engine: Any = None,
session: Any = None,
db_manager: Optional[DatabaseManager] = None,
) -> None:
"""BaseDAO, If the current database is a file database and create_not_exist_table=True, we will automatically create a table that does not exist"""
self._orm_base = orm_base
self._database = database
self._db_manager = db_manager or db
self._db_engine = db_engine
self._session = session
def get_raw_session(self) -> Session:
"""Get a raw session object.
def get_session(self):
Session = sessionmaker(autocommit=False, autoflush=False, bind=self._db_engine)
session = Session()
return session
Your should commit or rollback the session manually.
We suggest you use :meth:`session` instead.
Example:
.. code-block:: python
user = User(name="Edward Snowden")
session = self.get_raw_session()
session.add(user)
session.commit()
session.close()
"""
return self._db_manager._session()
@contextmanager
def session(self) -> Session:
"""Provide a transactional scope around a series of operations.
If raise an exception, the session will be roll back automatically, otherwise it will be committed.
Example:
.. code-block:: python
with self.session() as session:
session.query(User).filter(User.name == 'Edward Snowden').first()
Returns:
Session: A session object.
Raises:
Exception: Any exception will be raised.
"""
with self._db_manager.session() as session:
yield session

View File

@@ -0,0 +1,432 @@
from __future__ import annotations
import abc
from contextlib import contextmanager
from typing import TypeVar, Generic, Union, Dict, Optional, Type, Iterator, List
import logging
from sqlalchemy import create_engine, URL, Engine
from sqlalchemy import orm, inspect, MetaData
from sqlalchemy.orm import (
scoped_session,
sessionmaker,
Session,
declarative_base,
DeclarativeMeta,
)
from sqlalchemy.orm.session import _PKIdentityArgument
from sqlalchemy.orm.exc import UnmappedClassError
from sqlalchemy.pool import QueuePool
from dbgpt.util.string_utils import _to_str
from dbgpt.util.pagination_utils import PaginationResult
logger = logging.getLogger(__name__)
T = TypeVar("T", bound="BaseModel")
class _QueryObject:
"""The query object."""
def __init__(self, db_manager: "DatabaseManager"):
self._db_manager = db_manager
def __get__(self, obj, type):
try:
mapper = orm.class_mapper(type)
if mapper:
return type.query_class(mapper, session=self._db_manager._session())
except UnmappedClassError:
return None
class BaseQuery(orm.Query):
def paginate_query(
self, page: Optional[int] = 1, per_page: Optional[int] = 20
) -> PaginationResult:
"""Paginate the query.
Example:
.. code-block:: python
from dbgpt.storage.metadata import db, Model
class User(Model):
__tablename__ = "user"
id = Column(Integer, primary_key=True)
name = Column(String(50))
fullname = Column(String(50))
with db.session() as session:
pagination = session.query(User).paginate_query(page=1, page_size=10)
print(pagination)
# Or you can use the query object
with db.session() as session:
pagination = User.query.paginate_query(page=1, page_size=10)
print(pagination)
Args:
page (Optional[int], optional): The page number. Defaults to 1.
per_page (Optional[int], optional): The number of items per page. Defaults to 20.
Returns:
PaginationResult: The pagination result.
"""
if page < 1:
raise ValueError("Page must be greater than 0")
if per_page < 0:
raise ValueError("Per page must be greater than 0")
items = self.limit(per_page).offset((page - 1) * per_page).all()
total = self.order_by(None).count()
total_pages = (total - 1) // per_page + 1
return PaginationResult(
items=items,
total_count=total,
total_pages=total_pages,
page=page,
page_size=per_page,
)
class _Model:
"""Base class for SQLAlchemy declarative base model.
With this class, we can use the query object to query the database.
Examples:
.. code-block:: python
from dbgpt.storage.metadata import db, Model
class User(Model):
__tablename__ = "user"
id = Column(Integer, primary_key=True)
name = Column(String(50))
fullname = Column(String(50))
with db.session() as session:
# User is an instance of _Model, and we can use the query object to query the database.
User.query.filter(User.name == "test").all()
"""
query_class = None
query: Optional[BaseQuery] = None
def __repr__(self):
identity = inspect(self).identity
if identity is None:
pk = "(transient {0})".format(id(self))
else:
pk = ", ".join(_to_str(value) for value in identity)
return "<{0} {1}>".format(type(self).__name__, pk)
class DatabaseManager:
"""The database manager.
Examples:
.. code-block:: python
from urllib.parse import quote_plus as urlquote, quote
from dbgpt.storage.metadata import DatabaseManager, create_model
db = DatabaseManager()
# Use sqlite with memory storage.
url = f"sqlite:///:memory:"
engine_args = {"pool_size": 10, "max_overflow": 20, "pool_timeout": 30, "pool_recycle": 3600, "pool_pre_ping": True}
db.init_db(url, engine_args=engine_args)
Model = create_model(db)
class User(Model):
__tablename__ = "user"
id = Column(Integer, primary_key=True)
name = Column(String(50))
fullname = Column(String(50))
with db.session() as session:
session.add(User(name="test", fullname="test"))
# db will commit the session automatically default.
# session.commit()
print(User.query.filter(User.name == "test").all())
# Use CURDMixin APIs to create, update, delete, query the database.
with db.session() as session:
User.create(**{"name": "test1", "fullname": "test1"})
User.create(**{"name": "test2", "fullname": "test1"})
users = User.all()
print(users)
user = users[0]
user.update(**{"name": "test1_1111"})
user2 = users[1]
# Update user2 by save
user2.name = "test2_1111"
user2.save()
# Delete user2
user2.delete()
"""
Query = BaseQuery
def __init__(self):
self._db_url = None
self._base: DeclarativeMeta = self._make_declarative_base(_Model)
self._engine: Optional[Engine] = None
self._session: Optional[scoped_session] = None
@property
def Model(self) -> _Model:
"""Get the declarative base."""
return self._base
@property
def metadata(self) -> MetaData:
"""Get the metadata."""
return self.Model.metadata
@property
def engine(self):
"""Get the engine.""" ""
return self._engine
@contextmanager
def session(self) -> Session:
"""Get the session with context manager.
If raise any exception, the session will roll back automatically, otherwise, the session will commit automatically.
Example:
>>> with db.session() as session:
>>> session.query(...)
Returns:
Session: The session.
Raises:
RuntimeError: The database manager is not initialized.
Exception: Any exception.
"""
if not self._session:
raise RuntimeError("The database manager is not initialized.")
session = self._session()
try:
yield session
session.commit()
except:
session.rollback()
raise
finally:
session.close()
def _make_declarative_base(
self, model: Union[Type[DeclarativeMeta], Type[_Model]]
) -> DeclarativeMeta:
"""Make the declarative base.
Args:
base (DeclarativeMeta): The base class.
Returns:
DeclarativeMeta: The declarative base.
"""
if not isinstance(model, DeclarativeMeta):
model = declarative_base(cls=model, name="Model")
if not getattr(model, "query_class", None):
model.query_class = self.Query
model.query = _QueryObject(self)
return model
def init_db(
self,
db_url: Union[str, URL],
engine_args: Optional[Dict] = None,
base: Optional[DeclarativeMeta] = None,
query_class=BaseQuery,
):
"""Initialize the database manager.
Args:
db_url (Union[str, URL]): The database url.
engine_args (Optional[Dict], optional): The engine arguments. Defaults to None.
base (Optional[DeclarativeMeta]): The base class. Defaults to None.
query_class (BaseQuery, optional): The query class. Defaults to BaseQuery.
"""
self._db_url = db_url
if query_class is not None:
self.Query = query_class
if base is not None:
self._base = base
if not hasattr(base, "query"):
base.query = _QueryObject(self)
if not getattr(base, "query_class", None):
base.query_class = self.Query
self._engine = create_engine(db_url, **(engine_args or {}))
session_factory = sessionmaker(bind=self._engine)
self._session = scoped_session(session_factory)
self._base.metadata.bind = self._engine
def init_default_db(
self,
sqlite_path: str,
engine_args: Optional[Dict] = None,
base: Optional[DeclarativeMeta] = None,
):
"""Initialize the database manager with default config.
Examples:
>>> db.init_default_db(sqlite_path)
>>> with db.session() as session:
>>> session.query(...)
Args:
sqlite_path (str): The sqlite path.
engine_args (Optional[Dict], optional): The engine arguments.
Defaults to None, if None, we will use connection pool.
base (Optional[DeclarativeMeta]): The base class. Defaults to None.
"""
if not engine_args:
engine_args = {}
# Pool class
engine_args["poolclass"] = QueuePool
# The number of connections to keep open inside the connection pool.
engine_args["pool_size"] = 10
# The maximum overflow size of the pool when the number of connections be used in the pool is exceeded(
# pool_size).
engine_args["max_overflow"] = 20
# The number of seconds to wait before giving up on getting a connection from the pool.
engine_args["pool_timeout"] = 30
# Recycle the connection if it has been idle for this many seconds.
engine_args["pool_recycle"] = 3600
# Enable the connection pool “pre-ping” feature that tests connections for liveness upon each checkout.
engine_args["pool_pre_ping"] = True
self.init_db(f"sqlite:///{sqlite_path}", engine_args, base)
def create_all(self):
self.Model.metadata.create_all(self._engine)
db = DatabaseManager()
"""The global database manager.
Examples:
>>> from dbgpt.storage.metadata import db
>>> sqlite_path = "/tmp/dbgpt.db"
>>> db.init_default_db(sqlite_path)
>>> with db.session() as session:
>>> session.query(...)
>>> from dbgpt.storage.metadata import db, Model
>>> from urllib.parse import quote_plus as urlquote, quote
>>> db_name = "dbgpt"
>>> db_host = "localhost"
>>> db_port = 3306
>>> user = "root"
>>> password = "123456"
>>> url = f"mysql+pymysql://{quote(user)}:{urlquote(password)}@{db_host}:{str(db_port)}/{db_name}"
>>> engine_args = {"pool_size": 10, "max_overflow": 20, "pool_timeout": 30, "pool_recycle": 3600, "pool_pre_ping": True}
>>> db.init_db(url, engine_args=engine_args)
>>> class User(Model):
>>> __tablename__ = "user"
>>> id = Column(Integer, primary_key=True)
>>> name = Column(String(50))
>>> fullname = Column(String(50))
>>> with db.session() as session:
>>> session.add(User(name="test", fullname="test"))
>>> session.commit()
"""
class BaseCRUDMixin(Generic[T]):
"""The base CRUD mixin."""
__abstract__ = True
@classmethod
def create(cls: Type[T], **kwargs) -> T:
instance = cls(**kwargs)
return instance.save()
@classmethod
def all(cls: Type[T]) -> List[T]:
return cls.query.all()
@classmethod
def get(cls: Type[T], ident: _PKIdentityArgument) -> Optional[T]:
"""Get a record by its primary key identifier."""
def update(self: T, commit: Optional[bool] = True, **kwargs) -> T:
"""Update specific fields of a record."""
for attr, value in kwargs.items():
setattr(self, attr, value)
return commit and self.save() or self
@abc.abstractmethod
def save(self: T, commit: Optional[bool] = True) -> T:
"""Save the record."""
@abc.abstractmethod
def delete(self: T, commit: Optional[bool] = True) -> None:
"""Remove the record from the database."""
class BaseModel(BaseCRUDMixin[T], _Model, Generic[T]):
"""The base model class that includes CRUD convenience methods."""
__abstract__ = True
def create_model(db_manager: DatabaseManager) -> Type[BaseModel[T]]:
class CRUDMixin(BaseCRUDMixin[T], Generic[T]):
"""Mixin that adds convenience methods for CRUD (create, read, update, delete)"""
@classmethod
def get(cls: Type[T], ident: _PKIdentityArgument) -> Optional[T]:
"""Get a record by its primary key identifier."""
return db_manager._session().get(cls, ident)
def save(self: T, commit: Optional[bool] = True) -> T:
"""Save the record."""
session = db_manager._session()
session.add(self)
if commit:
session.commit()
return self
def delete(self: T, commit: Optional[bool] = True) -> None:
"""Remove the record from the database."""
session = db_manager._session()
session.delete(self)
return commit and session.commit()
class _NewModel(CRUDMixin[T], db_manager.Model, Generic[T]):
"""Base model class that includes CRUD convenience methods."""
__abstract__ = True
return _NewModel
Model = create_model(db)
def initialize_db(
db_url: Union[str, URL],
db_name: str,
engine_args: Optional[Dict] = None,
base: Optional[DeclarativeMeta] = None,
try_to_create_db: Optional[bool] = False,
) -> DatabaseManager:
"""Initialize the database manager.
Args:
db_url (Union[str, URL]): The database url.
db_name (str): The database name.
engine_args (Optional[Dict], optional): The engine arguments. Defaults to None.
base (Optional[DeclarativeMeta]): The base class. Defaults to None.
try_to_create_db (Optional[bool], optional): Whether to try to create the database. Defaults to False.
Returns:
DatabaseManager: The database manager.
"""
db.init_db(db_url, engine_args, base)
if try_to_create_db:
try:
db.create_all()
except Exception as e:
logger.error(f"Failed to create database {db_name}: {e}")
return db

View File

@@ -0,0 +1,128 @@
from contextlib import contextmanager
from typing import Type, List, Optional, Union, Dict
from dbgpt.core import Serializer
from dbgpt.core.interface.storage import (
StorageInterface,
QuerySpec,
ResourceIdentifier,
StorageItemAdapter,
T,
)
from sqlalchemy import URL
from sqlalchemy.orm import Session, DeclarativeMeta
from .db_manager import BaseModel, DatabaseManager, BaseQuery
def _copy_public_properties(src: BaseModel, dest: BaseModel):
"""Simple copy public properties from src to dest"""
for column in src.__table__.columns:
if column.name != "id":
setattr(dest, column.name, getattr(src, column.name))
class SQLAlchemyStorage(StorageInterface[T, BaseModel]):
def __init__(
self,
db_url_or_db: Union[str, URL, DatabaseManager],
model_class: Type[BaseModel],
adapter: StorageItemAdapter[T, BaseModel],
serializer: Optional[Serializer] = None,
engine_args: Optional[Dict] = None,
base: Optional[DeclarativeMeta] = None,
query_class=BaseQuery,
):
super().__init__(serializer=serializer, adapter=adapter)
if isinstance(db_url_or_db, str) or isinstance(db_url_or_db, URL):
db_manager = DatabaseManager()
db_manager.init_db(db_url_or_db, engine_args, base, query_class)
self.db_manager = db_manager
elif isinstance(db_url_or_db, DatabaseManager):
self.db_manager = db_url_or_db
else:
raise ValueError(
f"db_url_or_db should be either url or a DatabaseManager, got {type(db_url_or_db)}"
)
self._model_class = model_class
@contextmanager
def session(self) -> Session:
with self.db_manager.session() as session:
yield session
def save(self, data: T) -> None:
with self.session() as session:
model_instance = self.adapter.to_storage_format(data)
session.add(model_instance)
def update(self, data: T) -> None:
with self.session() as session:
model_instance = self.adapter.to_storage_format(data)
session.merge(model_instance)
def save_or_update(self, data: T) -> None:
with self.session() as session:
query = self.adapter.get_query_for_identifier(
self._model_class, data.identifier, session=session
)
model_instance = query.with_session(session).first()
if model_instance:
new_instance = self.adapter.to_storage_format(data)
_copy_public_properties(new_instance, model_instance)
session.merge(model_instance)
return
self.save(data)
def load(self, resource_id: ResourceIdentifier, cls: Type[T]) -> Optional[T]:
with self.session() as session:
query = self.adapter.get_query_for_identifier(
self._model_class, resource_id, session=session
)
model_instance = query.with_session(session).first()
if model_instance:
return self.adapter.from_storage_format(model_instance)
return None
def delete(self, resource_id: ResourceIdentifier) -> None:
with self.session() as session:
query = self.adapter.get_query_for_identifier(
self._model_class, resource_id, session=session
)
model_instance = query.with_session(session).first()
if model_instance:
session.delete(model_instance)
def query(self, spec: QuerySpec, cls: Type[T]) -> List[T]:
"""Query data from the storage.
Args:
spec (QuerySpec): The query specification
cls (Type[T]): The type of the data
"""
with self.session() as session:
query = session.query(self._model_class)
for key, value in spec.conditions.items():
query = query.filter(getattr(self._model_class, key) == value)
if spec.limit is not None:
query = query.limit(spec.limit)
if spec.offset is not None:
query = query.offset(spec.offset)
model_instances = query.all()
return [
self.adapter.from_storage_format(instance)
for instance in model_instances
]
def count(self, spec: QuerySpec, cls: Type[T]) -> int:
"""Count the number of data in the storage.
Args:
spec (QuerySpec): The query specification
cls (Type[T]): The type of the data
"""
with self.session() as session:
query = session.query(self._model_class)
for key, value in spec.conditions.items():
query = query.filter(getattr(self._model_class, key) == value)
return query.count()

View File

@@ -1,94 +0,0 @@
import os
import sqlite3
import logging
from sqlalchemy import create_engine, DDL
from sqlalchemy.exc import OperationalError
from sqlalchemy.orm import sessionmaker
from sqlalchemy.ext.declarative import declarative_base
from alembic import command
from alembic.config import Config as AlembicConfig
from urllib.parse import quote
from dbgpt._private.config import Config
from dbgpt.configs.model_config import PILOT_PATH
from urllib.parse import quote_plus as urlquote
logger = logging.getLogger(__name__)
# DB-GPT metadata database config, now support mysql and sqlite
CFG = Config()
default_db_path = os.path.join(PILOT_PATH, "meta_data")
os.makedirs(default_db_path, exist_ok=True)
# Meta Info
META_DATA_DATABASE = CFG.LOCAL_DB_NAME
db_name = META_DATA_DATABASE
db_path = default_db_path + f"/{db_name}.db"
connection = sqlite3.connect(db_path)
if CFG.LOCAL_DB_TYPE == "mysql":
engine_temp = create_engine(
f"mysql+pymysql://{quote(CFG.LOCAL_DB_USER)}:{urlquote(CFG.LOCAL_DB_PASSWORD)}@{CFG.LOCAL_DB_HOST}:{str(CFG.LOCAL_DB_PORT)}"
)
# check and auto create mysqldatabase
try:
# try to connect
with engine_temp.connect() as conn:
# TODO We should consider that the production environment does not have permission to execute the DDL
conn.execute(DDL(f"CREATE DATABASE IF NOT EXISTS {db_name}"))
print(f"Already connect '{db_name}'")
except OperationalError as e:
# if connect failed, create dbgpt database
logger.error(f"{db_name} not connect success!")
engine = create_engine(
f"mysql+pymysql://{quote(CFG.LOCAL_DB_USER)}:{urlquote(CFG.LOCAL_DB_PASSWORD)}@{CFG.LOCAL_DB_HOST}:{str(CFG.LOCAL_DB_PORT)}/{db_name}"
)
else:
engine = create_engine(f"sqlite:///{db_path}")
Session = sessionmaker(autocommit=False, autoflush=False, bind=engine)
session = Session()
Base = declarative_base()
# Base.metadata.create_all()
alembic_ini_path = default_db_path + "/alembic.ini"
alembic_cfg = AlembicConfig(alembic_ini_path)
alembic_cfg.set_main_option("sqlalchemy.url", str(engine.url))
os.makedirs(default_db_path + "/alembic", exist_ok=True)
os.makedirs(default_db_path + "/alembic/versions", exist_ok=True)
alembic_cfg.set_main_option("script_location", default_db_path + "/alembic")
alembic_cfg.attributes["target_metadata"] = Base.metadata
alembic_cfg.attributes["session"] = session
def ddl_init_and_upgrade(disable_alembic_upgrade: bool):
"""Initialize and upgrade database metadata
Args:
disable_alembic_upgrade (bool): Whether to enable alembic to initialize and upgrade database metadata
"""
if disable_alembic_upgrade:
logger.info(
"disable_alembic_upgrade is true, not to initialize and upgrade database metadata with alembic"
)
return
with engine.connect() as connection:
alembic_cfg.attributes["connection"] = connection
heads = command.heads(alembic_cfg)
print("heads:" + str(heads))
command.revision(alembic_cfg, "dbgpt ddl upate", True)
command.upgrade(alembic_cfg, "head")

View File

View File

@@ -0,0 +1,129 @@
from __future__ import annotations
import pytest
from typing import Type
from dbgpt.storage.metadata.db_manager import (
DatabaseManager,
PaginationResult,
create_model,
BaseModel,
)
from sqlalchemy import Column, Integer, String
@pytest.fixture
def db():
db = DatabaseManager()
db.init_db("sqlite:///:memory:")
return db
@pytest.fixture
def Model(db):
return create_model(db)
def test_database_initialization(db: DatabaseManager, Model: Type[BaseModel]):
assert db.engine is not None
assert db.session is not None
with db.session() as session:
assert session is not None
def test_model_creation(db: DatabaseManager, Model: Type[BaseModel]):
assert db.metadata.tables == {}
class User(Model):
__tablename__ = "user"
id = Column(Integer, primary_key=True)
name = Column(String(50))
db.create_all()
assert list(db.metadata.tables.keys())[0] == "user"
def test_crud_operations(db: DatabaseManager, Model: Type[BaseModel]):
class User(Model):
__tablename__ = "user"
id = Column(Integer, primary_key=True)
name = Column(String(50))
db.create_all()
# Create
with db.session() as session:
user = User.create(name="John Doe")
session.add(user)
session.commit()
# Read
with db.session() as session:
user = session.query(User).filter_by(name="John Doe").first()
assert user is not None
# Update
with db.session() as session:
user = session.query(User).filter_by(name="John Doe").first()
user.update(name="Jane Doe")
# Delete
with db.session() as session:
user = session.query(User).filter_by(name="Jane Doe").first()
user.delete()
def test_crud_mixins(db: DatabaseManager, Model: Type[BaseModel]):
class User(Model):
__tablename__ = "user"
id = Column(Integer, primary_key=True)
name = Column(String(50))
db.create_all()
# Create
user = User.create(name="John Doe")
assert User.get(user.id) is not None
users = User.all()
assert len(users) == 1
# Update
user.update(name="Bob Doe")
assert User.get(user.id).name == "Bob Doe"
user = User.get(user.id)
user.delete()
assert User.get(user.id) is None
def test_pagination_query(db: DatabaseManager, Model: Type[BaseModel]):
class User(Model):
__tablename__ = "user"
id = Column(Integer, primary_key=True)
name = Column(String(50))
db.create_all()
# 添加数据
with db.session() as session:
for i in range(30):
user = User(name=f"User {i}")
session.add(user)
session.commit()
users_page_1 = User.query.paginate_query(page=1, per_page=10)
assert len(users_page_1.items) == 10
assert users_page_1.total_pages == 3
def test_invalid_pagination(db: DatabaseManager, Model: Type[BaseModel]):
class User(Model):
__tablename__ = "user"
id = Column(Integer, primary_key=True)
name = Column(String(50))
db.create_all()
with pytest.raises(ValueError):
User.query.paginate_query(page=0, per_page=10)
with pytest.raises(ValueError):
User.query.paginate_query(page=1, per_page=-1)

View File

@@ -0,0 +1,173 @@
from typing import Dict, Type
from sqlalchemy.orm import declarative_base, Session
from sqlalchemy import Column, Integer, String
import pytest
from dbgpt.core.interface.storage import (
StorageItem,
ResourceIdentifier,
StorageItemAdapter,
QuerySpec,
)
from dbgpt.storage.metadata.db_storage import SQLAlchemyStorage
from dbgpt.core.interface.tests.test_storage import MockResourceIdentifier
from dbgpt.util.serialization.json_serialization import JsonSerializer
Base = declarative_base()
class MockModel(Base):
"""The SQLAlchemy model for the mock data."""
__tablename__ = "mock_data"
id = Column(Integer, primary_key=True)
data = Column(String)
class MockStorageItem(StorageItem):
"""The mock storage item."""
def merge(self, other: "StorageItem") -> None:
if not isinstance(other, MockStorageItem):
raise ValueError("other must be a MockStorageItem")
self.data = other.data
def __init__(self, identifier: ResourceIdentifier, data: str):
self._identifier = identifier
self.data = data
@property
def identifier(self) -> ResourceIdentifier:
return self._identifier
def to_dict(self) -> Dict:
return {"identifier": self._identifier, "data": self.data}
def serialize(self) -> bytes:
return str(self.data).encode()
class MockStorageItemAdapter(StorageItemAdapter[MockStorageItem, MockModel]):
"""The adapter for the mock storage item."""
def to_storage_format(self, item: MockStorageItem) -> MockModel:
return MockModel(id=int(item.identifier.str_identifier), data=item.data)
def from_storage_format(self, model: MockModel) -> MockStorageItem:
return MockStorageItem(MockResourceIdentifier(str(model.id)), model.data)
def get_query_for_identifier(
self,
storage_format: Type[MockModel],
resource_id: ResourceIdentifier,
**kwargs,
):
session: Session = kwargs.get("session")
if session is None:
raise ValueError("session is required for this adapter")
return session.query(storage_format).filter(
storage_format.id == int(resource_id.str_identifier)
)
@pytest.fixture
def serializer():
return JsonSerializer()
@pytest.fixture
def db_url():
"""Use in-memory SQLite database for testing"""
return "sqlite:///:memory:"
@pytest.fixture
def sqlalchemy_storage(db_url, serializer):
adapter = MockStorageItemAdapter()
storage = SQLAlchemyStorage(db_url, MockModel, adapter, serializer, base=Base)
Base.metadata.create_all(storage.db_manager.engine)
return storage
def test_save_and_load(sqlalchemy_storage):
item = MockStorageItem(MockResourceIdentifier("1"), "test_data")
sqlalchemy_storage.save(item)
loaded_item = sqlalchemy_storage.load(MockResourceIdentifier("1"), MockStorageItem)
assert loaded_item.data == "test_data"
def test_delete(sqlalchemy_storage):
resource_id = MockResourceIdentifier("1")
sqlalchemy_storage.delete(resource_id)
# Make sure the item is deleted
assert sqlalchemy_storage.load(resource_id, MockStorageItem) is None
def test_query_with_various_conditions(sqlalchemy_storage):
# Add multiple items for testing
for i in range(5):
item = MockStorageItem(MockResourceIdentifier(str(i)), f"test_data_{i}")
sqlalchemy_storage.save(item)
# Test query with single condition
query_spec = QuerySpec(conditions={"data": "test_data_2"})
results = sqlalchemy_storage.query(query_spec, MockStorageItem)
assert len(results) == 1
assert results[0].data == "test_data_2"
# Test not existing condition
query_spec = QuerySpec(conditions={"data": "nonexistent"})
results = sqlalchemy_storage.query(query_spec, MockStorageItem)
assert len(results) == 0
# Test query with multiple conditions
query_spec = QuerySpec(conditions={"data": "test_data_2", "id": "2"})
results = sqlalchemy_storage.query(query_spec, MockStorageItem)
assert len(results) == 1
def test_query_nonexistent_item(sqlalchemy_storage):
query_spec = QuerySpec(conditions={"data": "nonexistent"})
results = sqlalchemy_storage.query(query_spec, MockStorageItem)
assert len(results) == 0
def test_count_items(sqlalchemy_storage):
for i in range(5):
item = MockStorageItem(MockResourceIdentifier(str(i)), f"test_data_{i}")
sqlalchemy_storage.save(item)
# Test count without conditions
query_spec = QuerySpec(conditions={})
total_count = sqlalchemy_storage.count(query_spec, MockStorageItem)
assert total_count == 5
# Test count with conditions
query_spec = QuerySpec(conditions={"data": "test_data_2"})
total_count = sqlalchemy_storage.count(query_spec, MockStorageItem)
assert total_count == 1
def test_paginate_query(sqlalchemy_storage):
for i in range(10):
item = MockStorageItem(MockResourceIdentifier(str(i)), f"test_data_{i}")
sqlalchemy_storage.save(item)
page_size = 3
page_number = 2
query_spec = QuerySpec(conditions={})
page_result = sqlalchemy_storage.paginate_query(
page_number, page_size, MockStorageItem, query_spec
)
assert len(page_result.items) == page_size
assert page_result.page == page_number
assert page_result.total_pages == 4
assert page_result.total_count == 10