mirror of https://github.com/csunny/DB-GPT.git (synced 2025-09-16 14:40:56 +00:00)
refactor: Refactor storage system (#937)
@@ -1 +1,17 @@
from dbgpt.storage.metadata.db_manager import (
    db,
    Model,
    DatabaseManager,
    create_model,
    BaseModel,
)
from dbgpt.storage.metadata._base_dao import BaseDao

__ALL__ = [
    "db",
    "Model",
    "DatabaseManager",
    "create_model",
    "BaseModel",
    "BaseDao",
]
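For orientation, the snippet below is a small sketch (not part of the commit) showing the new public surface re-exported above; the in-memory SQLite URL is only illustrative.

# Sketch only (not part of this commit): the new public surface in one place.
from dbgpt.storage.metadata import BaseDao, DatabaseManager, Model, create_model, db

db.init_db("sqlite:///:memory:")  # illustrative in-memory SQLite URL
print(db.engine)                  # the global manager now has an engine and session factory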
@@ -1,25 +1,72 @@
from typing import TypeVar, Generic, Any
from sqlalchemy.orm import sessionmaker
from contextlib import contextmanager
from typing import TypeVar, Generic, Any, Optional
from sqlalchemy.orm.session import Session

T = TypeVar("T")

from .db_manager import db, DatabaseManager


class BaseDao(Generic[T]):
    """The base class for all DAOs.

    Examples:
        .. code-block:: python
            class UserDao(BaseDao[User]):
                def get_user_by_name(self, name: str) -> User:
                    with self.session() as session:
                        return session.query(User).filter(User.name == name).first()

                def get_user_by_id(self, id: int) -> User:
                    with self.session() as session:
                        return User.get(id)

                def create_user(self, name: str) -> User:
                    return User.create(**{"name": name})

    Args:
        db_manager (DatabaseManager, optional): The database manager. Defaults to None.
            If None, the default database manager (db) will be used.
    """

    def __init__(
        self,
        orm_base=None,
        database: str = None,
        db_engine: Any = None,
        session: Any = None,
        db_manager: Optional[DatabaseManager] = None,
    ) -> None:
        """BaseDao. If the current database is a file database and create_not_exist_table=True, tables that do not exist are created automatically."""
        self._orm_base = orm_base
        self._database = database
        self._db_manager = db_manager or db

        self._db_engine = db_engine
        self._session = session

    def get_session(self):
        Session = sessionmaker(autocommit=False, autoflush=False, bind=self._db_engine)
        session = Session()
        return session

    def get_raw_session(self) -> Session:
        """Get a raw session object.

        You should commit or rollback the session manually.
        We suggest you use :meth:`session` instead.

        Example:
            .. code-block:: python
                user = User(name="Edward Snowden")
                session = self.get_raw_session()
                session.add(user)
                session.commit()
                session.close()
        """
        return self._db_manager._session()

    @contextmanager
    def session(self) -> Session:
        """Provide a transactional scope around a series of operations.

        If an exception is raised, the session will be rolled back automatically; otherwise it will be committed.

        Example:
            .. code-block:: python
                with self.session() as session:
                    session.query(User).filter(User.name == 'Edward Snowden').first()

        Returns:
            Session: A session object.

        Raises:
            Exception: Any exception will be raised.
        """
        with self._db_manager.session() as session:
            yield session
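To make the DatabaseManager-backed DAO pattern above concrete, here is a small sketch (not part of the commit); the User model and UserDao are hypothetical names used only for illustration.

# Sketch only (not part of this commit): User and UserDao are hypothetical examples
# built on the BaseDao and global db shown above.
from sqlalchemy import Column, Integer, String

from dbgpt.storage.metadata import BaseDao, Model, db


class User(Model):
    __tablename__ = "user"

    id = Column(Integer, primary_key=True)
    name = Column(String(50))


class UserDao(BaseDao[User]):
    def find_by_name(self, name: str):
        # session() yields a transactional scope from the underlying DatabaseManager,
        # committing on success and rolling back on error.
        with self.session() as session:
            return session.query(User).filter(User.name == name).first()


db.init_db("sqlite:///:memory:")  # illustrative URL
db.create_all()

with db.session() as session:
    session.add(User(name="alice"))

dao = UserDao()  # no db_manager passed, so the global `db` is used
print(dao.find_by_name("alice"))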
dbgpt/storage/metadata/db_manager.py (new file, 432 lines)
@@ -0,0 +1,432 @@
from __future__ import annotations

import abc
from contextlib import contextmanager
from typing import TypeVar, Generic, Union, Dict, Optional, Type, Iterator, List
import logging
from sqlalchemy import create_engine, URL, Engine
from sqlalchemy import orm, inspect, MetaData
from sqlalchemy.orm import (
    scoped_session,
    sessionmaker,
    Session,
    declarative_base,
    DeclarativeMeta,
)
from sqlalchemy.orm.session import _PKIdentityArgument
from sqlalchemy.orm.exc import UnmappedClassError

from sqlalchemy.pool import QueuePool
from dbgpt.util.string_utils import _to_str
from dbgpt.util.pagination_utils import PaginationResult

logger = logging.getLogger(__name__)
T = TypeVar("T", bound="BaseModel")


class _QueryObject:
    """The query object."""

    def __init__(self, db_manager: "DatabaseManager"):
        self._db_manager = db_manager

    def __get__(self, obj, type):
        try:
            mapper = orm.class_mapper(type)
            if mapper:
                return type.query_class(mapper, session=self._db_manager._session())
        except UnmappedClassError:
            return None
class BaseQuery(orm.Query):
    def paginate_query(
        self, page: Optional[int] = 1, per_page: Optional[int] = 20
    ) -> PaginationResult:
        """Paginate the query.

        Example:
            .. code-block:: python
                from dbgpt.storage.metadata import db, Model
                class User(Model):
                    __tablename__ = "user"
                    id = Column(Integer, primary_key=True)
                    name = Column(String(50))
                    fullname = Column(String(50))

                with db.session() as session:
                    pagination = session.query(User).paginate_query(page=1, per_page=10)
                    print(pagination)

                # Or you can use the query object
                with db.session() as session:
                    pagination = User.query.paginate_query(page=1, per_page=10)
                    print(pagination)

        Args:
            page (Optional[int], optional): The page number. Defaults to 1.
            per_page (Optional[int], optional): The number of items per page. Defaults to 20.
        Returns:
            PaginationResult: The pagination result.
        """
        if page < 1:
            raise ValueError("Page must be greater than 0")
        if per_page < 0:
            raise ValueError("Per page must be greater than 0")
        items = self.limit(per_page).offset((page - 1) * per_page).all()
        total = self.order_by(None).count()
        total_pages = (total - 1) // per_page + 1
        return PaginationResult(
            items=items,
            total_count=total,
            total_pages=total_pages,
            page=page,
            page_size=per_page,
        )
class _Model:
    """Base class for SQLAlchemy declarative base model.

    With this class, we can use the query object to query the database.

    Examples:
        .. code-block:: python
            from dbgpt.storage.metadata import db, Model
            class User(Model):
                __tablename__ = "user"
                id = Column(Integer, primary_key=True)
                name = Column(String(50))
                fullname = Column(String(50))

            with db.session() as session:
                # User is built on _Model, so we can use the query object to query the database.
                User.query.filter(User.name == "test").all()
    """

    query_class = None
    query: Optional[BaseQuery] = None

    def __repr__(self):
        identity = inspect(self).identity
        if identity is None:
            pk = "(transient {0})".format(id(self))
        else:
            pk = ", ".join(_to_str(value) for value in identity)
        return "<{0} {1}>".format(type(self).__name__, pk)
class DatabaseManager:
    """The database manager.

    Examples:
        .. code-block:: python
            from urllib.parse import quote_plus as urlquote, quote
            from dbgpt.storage.metadata import DatabaseManager, create_model
            db = DatabaseManager()
            # Use sqlite with memory storage.
            url = f"sqlite:///:memory:"
            engine_args = {"pool_size": 10, "max_overflow": 20, "pool_timeout": 30, "pool_recycle": 3600, "pool_pre_ping": True}
            db.init_db(url, engine_args=engine_args)

            Model = create_model(db)

            class User(Model):
                __tablename__ = "user"
                id = Column(Integer, primary_key=True)
                name = Column(String(50))
                fullname = Column(String(50))

            with db.session() as session:
                session.add(User(name="test", fullname="test"))
                # db will commit the session automatically by default.
                # session.commit()
                print(User.query.filter(User.name == "test").all())


            # Use the CRUD mixin APIs to create, update, delete, and query the database.
            with db.session() as session:
                User.create(**{"name": "test1", "fullname": "test1"})
                User.create(**{"name": "test2", "fullname": "test1"})
                users = User.all()
                print(users)
                user = users[0]
                user.update(**{"name": "test1_1111"})
                user2 = users[1]
                # Update user2 by save
                user2.name = "test2_1111"
                user2.save()
                # Delete user2
                user2.delete()
    """
    Query = BaseQuery

    def __init__(self):
        self._db_url = None
        self._base: DeclarativeMeta = self._make_declarative_base(_Model)
        self._engine: Optional[Engine] = None
        self._session: Optional[scoped_session] = None

    @property
    def Model(self) -> _Model:
        """Get the declarative base."""
        return self._base

    @property
    def metadata(self) -> MetaData:
        """Get the metadata."""
        return self.Model.metadata

    @property
    def engine(self):
        """Get the engine."""
        return self._engine
    @contextmanager
    def session(self) -> Session:
        """Get the session with a context manager.

        If any exception is raised, the session will be rolled back automatically; otherwise, the session will be committed automatically.

        Example:
            >>> with db.session() as session:
            >>>     session.query(...)

        Returns:
            Session: The session.

        Raises:
            RuntimeError: The database manager is not initialized.
            Exception: Any exception.
        """
        if not self._session:
            raise RuntimeError("The database manager is not initialized.")
        session = self._session()
        try:
            yield session
            session.commit()
        except:
            session.rollback()
            raise
        finally:
            session.close()
    def _make_declarative_base(
        self, model: Union[Type[DeclarativeMeta], Type[_Model]]
    ) -> DeclarativeMeta:
        """Make the declarative base.

        Args:
            model (Union[Type[DeclarativeMeta], Type[_Model]]): The base model class.

        Returns:
            DeclarativeMeta: The declarative base.
        """
        if not isinstance(model, DeclarativeMeta):
            model = declarative_base(cls=model, name="Model")
        if not getattr(model, "query_class", None):
            model.query_class = self.Query
        model.query = _QueryObject(self)
        return model
    def init_db(
        self,
        db_url: Union[str, URL],
        engine_args: Optional[Dict] = None,
        base: Optional[DeclarativeMeta] = None,
        query_class=BaseQuery,
    ):
        """Initialize the database manager.

        Args:
            db_url (Union[str, URL]): The database url.
            engine_args (Optional[Dict], optional): The engine arguments. Defaults to None.
            base (Optional[DeclarativeMeta]): The base class. Defaults to None.
            query_class (BaseQuery, optional): The query class. Defaults to BaseQuery.
        """
        self._db_url = db_url
        if query_class is not None:
            self.Query = query_class
        if base is not None:
            self._base = base
            if not hasattr(base, "query"):
                base.query = _QueryObject(self)
            if not getattr(base, "query_class", None):
                base.query_class = self.Query
        self._engine = create_engine(db_url, **(engine_args or {}))
        session_factory = sessionmaker(bind=self._engine)
        self._session = scoped_session(session_factory)
        self._base.metadata.bind = self._engine
    def init_default_db(
        self,
        sqlite_path: str,
        engine_args: Optional[Dict] = None,
        base: Optional[DeclarativeMeta] = None,
    ):
        """Initialize the database manager with the default config.

        Examples:
            >>> db.init_default_db(sqlite_path)
            >>> with db.session() as session:
            >>>     session.query(...)

        Args:
            sqlite_path (str): The sqlite path.
            engine_args (Optional[Dict], optional): The engine arguments.
                Defaults to None. If None, a default connection pool is used.
            base (Optional[DeclarativeMeta]): The base class. Defaults to None.
        """
        if not engine_args:
            engine_args = {}
            # Pool class
            engine_args["poolclass"] = QueuePool
            # The number of connections to keep open inside the connection pool.
            engine_args["pool_size"] = 10
            # The maximum overflow size of the pool when the number of connections
            # in use exceeds pool_size.
            engine_args["max_overflow"] = 20
            # The number of seconds to wait before giving up on getting a connection from the pool.
            engine_args["pool_timeout"] = 30
            # Recycle the connection if it has been idle for this many seconds.
            engine_args["pool_recycle"] = 3600
            # Enable the connection pool "pre-ping" feature that tests connections
            # for liveness upon each checkout.
            engine_args["pool_pre_ping"] = True

        self.init_db(f"sqlite:///{sqlite_path}", engine_args, base)
    def create_all(self):
        self.Model.metadata.create_all(self._engine)


db = DatabaseManager()
"""The global database manager.

Examples:
    >>> from dbgpt.storage.metadata import db
    >>> sqlite_path = "/tmp/dbgpt.db"
    >>> db.init_default_db(sqlite_path)
    >>> with db.session() as session:
    >>>     session.query(...)

    >>> from dbgpt.storage.metadata import db, Model
    >>> from urllib.parse import quote_plus as urlquote, quote
    >>> db_name = "dbgpt"
    >>> db_host = "localhost"
    >>> db_port = 3306
    >>> user = "root"
    >>> password = "123456"
    >>> url = f"mysql+pymysql://{quote(user)}:{urlquote(password)}@{db_host}:{str(db_port)}/{db_name}"
    >>> engine_args = {"pool_size": 10, "max_overflow": 20, "pool_timeout": 30, "pool_recycle": 3600, "pool_pre_ping": True}
    >>> db.init_db(url, engine_args=engine_args)
    >>> class User(Model):
    >>>     __tablename__ = "user"
    >>>     id = Column(Integer, primary_key=True)
    >>>     name = Column(String(50))
    >>>     fullname = Column(String(50))
    >>> with db.session() as session:
    >>>     session.add(User(name="test", fullname="test"))
    >>>     session.commit()
"""


class BaseCRUDMixin(Generic[T]):
    """The base CRUD mixin."""

    __abstract__ = True

    @classmethod
    def create(cls: Type[T], **kwargs) -> T:
        instance = cls(**kwargs)
        return instance.save()

    @classmethod
    def all(cls: Type[T]) -> List[T]:
        return cls.query.all()

    @classmethod
    def get(cls: Type[T], ident: _PKIdentityArgument) -> Optional[T]:
        """Get a record by its primary key identifier."""

    def update(self: T, commit: Optional[bool] = True, **kwargs) -> T:
        """Update specific fields of a record."""
        for attr, value in kwargs.items():
            setattr(self, attr, value)
        return commit and self.save() or self

    @abc.abstractmethod
    def save(self: T, commit: Optional[bool] = True) -> T:
        """Save the record."""

    @abc.abstractmethod
    def delete(self: T, commit: Optional[bool] = True) -> None:
        """Remove the record from the database."""


class BaseModel(BaseCRUDMixin[T], _Model, Generic[T]):
    """The base model class that includes CRUD convenience methods."""

    __abstract__ = True


def create_model(db_manager: DatabaseManager) -> Type[BaseModel[T]]:
    class CRUDMixin(BaseCRUDMixin[T], Generic[T]):
        """Mixin that adds convenience methods for CRUD (create, read, update, delete)"""

        @classmethod
        def get(cls: Type[T], ident: _PKIdentityArgument) -> Optional[T]:
            """Get a record by its primary key identifier."""
            return db_manager._session().get(cls, ident)

        def save(self: T, commit: Optional[bool] = True) -> T:
            """Save the record."""
            session = db_manager._session()
            session.add(self)
            if commit:
                session.commit()
            return self

        def delete(self: T, commit: Optional[bool] = True) -> None:
            """Remove the record from the database."""
            session = db_manager._session()
            session.delete(self)
            return commit and session.commit()

    class _NewModel(CRUDMixin[T], db_manager.Model, Generic[T]):
        """Base model class that includes CRUD convenience methods."""

        __abstract__ = True

    return _NewModel


Model = create_model(db)


def initialize_db(
    db_url: Union[str, URL],
    db_name: str,
    engine_args: Optional[Dict] = None,
    base: Optional[DeclarativeMeta] = None,
    try_to_create_db: Optional[bool] = False,
) -> DatabaseManager:
    """Initialize the database manager.

    Args:
        db_url (Union[str, URL]): The database url.
        db_name (str): The database name.
        engine_args (Optional[Dict], optional): The engine arguments. Defaults to None.
        base (Optional[DeclarativeMeta]): The base class. Defaults to None.
        try_to_create_db (Optional[bool], optional): Whether to try to create the database. Defaults to False.
    Returns:
        DatabaseManager: The database manager.
    """
    db.init_db(db_url, engine_args, base)
    if try_to_create_db:
        try:
            db.create_all()
        except Exception as e:
            logger.error(f"Failed to create database {db_name}: {e}")
    return db
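A short sketch (not part of the commit) combining the module-level initialize_db helper with BaseQuery.paginate_query; the Task model and the SQLite URL are hypothetical and only illustrate the API shown above.

# Sketch only (not part of this commit): Task is a hypothetical model.
from sqlalchemy import Column, Integer, String

from dbgpt.storage.metadata.db_manager import Model, initialize_db


class Task(Model):
    __tablename__ = "task"

    id = Column(Integer, primary_key=True)
    name = Column(String(100))


# try_to_create_db=True calls create_all() and logs (rather than raises) on failure.
db = initialize_db("sqlite:///:memory:", "dbgpt", try_to_create_db=True)

with db.session() as session:
    for i in range(45):
        session.add(Task(name=f"task-{i}"))

# 45 rows at 20 per page -> (45 - 1) // 20 + 1 == 3 pages.
page = Task.query.paginate_query(page=2, per_page=20)
print(page.total_count, page.total_pages, len(page.items))  # 45 3 20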
dbgpt/storage/metadata/db_storage.py (new file, 128 lines)
@@ -0,0 +1,128 @@
from contextlib import contextmanager

from typing import Type, List, Optional, Union, Dict
from dbgpt.core import Serializer
from dbgpt.core.interface.storage import (
    StorageInterface,
    QuerySpec,
    ResourceIdentifier,
    StorageItemAdapter,
    T,
)
from sqlalchemy import URL
from sqlalchemy.orm import Session, DeclarativeMeta

from .db_manager import BaseModel, DatabaseManager, BaseQuery


def _copy_public_properties(src: BaseModel, dest: BaseModel):
    """Copy public column values (everything except "id") from src to dest."""
    for column in src.__table__.columns:
        if column.name != "id":
            setattr(dest, column.name, getattr(src, column.name))
class SQLAlchemyStorage(StorageInterface[T, BaseModel]):
    def __init__(
        self,
        db_url_or_db: Union[str, URL, DatabaseManager],
        model_class: Type[BaseModel],
        adapter: StorageItemAdapter[T, BaseModel],
        serializer: Optional[Serializer] = None,
        engine_args: Optional[Dict] = None,
        base: Optional[DeclarativeMeta] = None,
        query_class=BaseQuery,
    ):
        super().__init__(serializer=serializer, adapter=adapter)
        if isinstance(db_url_or_db, str) or isinstance(db_url_or_db, URL):
            db_manager = DatabaseManager()
            db_manager.init_db(db_url_or_db, engine_args, base, query_class)
            self.db_manager = db_manager
        elif isinstance(db_url_or_db, DatabaseManager):
            self.db_manager = db_url_or_db
        else:
            raise ValueError(
                f"db_url_or_db should be either url or a DatabaseManager, got {type(db_url_or_db)}"
            )
        self._model_class = model_class

    @contextmanager
    def session(self) -> Session:
        with self.db_manager.session() as session:
            yield session

    def save(self, data: T) -> None:
        with self.session() as session:
            model_instance = self.adapter.to_storage_format(data)
            session.add(model_instance)

    def update(self, data: T) -> None:
        with self.session() as session:
            model_instance = self.adapter.to_storage_format(data)
            session.merge(model_instance)

    def save_or_update(self, data: T) -> None:
        with self.session() as session:
            query = self.adapter.get_query_for_identifier(
                self._model_class, data.identifier, session=session
            )
            model_instance = query.with_session(session).first()
            if model_instance:
                new_instance = self.adapter.to_storage_format(data)
                _copy_public_properties(new_instance, model_instance)
                session.merge(model_instance)
                return
        self.save(data)

    def load(self, resource_id: ResourceIdentifier, cls: Type[T]) -> Optional[T]:
        with self.session() as session:
            query = self.adapter.get_query_for_identifier(
                self._model_class, resource_id, session=session
            )
            model_instance = query.with_session(session).first()
            if model_instance:
                return self.adapter.from_storage_format(model_instance)
            return None

    def delete(self, resource_id: ResourceIdentifier) -> None:
        with self.session() as session:
            query = self.adapter.get_query_for_identifier(
                self._model_class, resource_id, session=session
            )
            model_instance = query.with_session(session).first()
            if model_instance:
                session.delete(model_instance)

    def query(self, spec: QuerySpec, cls: Type[T]) -> List[T]:
        """Query data from the storage.

        Args:
            spec (QuerySpec): The query specification
            cls (Type[T]): The type of the data
        """
        with self.session() as session:
            query = session.query(self._model_class)
            for key, value in spec.conditions.items():
                query = query.filter(getattr(self._model_class, key) == value)
            if spec.limit is not None:
                query = query.limit(spec.limit)
            if spec.offset is not None:
                query = query.offset(spec.offset)
            model_instances = query.all()
            return [
                self.adapter.from_storage_format(instance)
                for instance in model_instances
            ]

    def count(self, spec: QuerySpec, cls: Type[T]) -> int:
        """Count the number of data in the storage.

        Args:
            spec (QuerySpec): The query specification
            cls (Type[T]): The type of the data
        """
        with self.session() as session:
            query = session.query(self._model_class)
            for key, value in spec.conditions.items():
                query = query.filter(getattr(self._model_class, key) == value)
            return query.count()
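A condensed sketch (not part of the commit) of how SQLAlchemyStorage is wired together, reusing the mock model, item, and adapter defined in the test file later in this diff; the data values are illustrative.

# Sketch only (not part of this commit), condensed from the tests further below.
from dbgpt.core.interface.tests.test_storage import MockResourceIdentifier
from dbgpt.storage.metadata.db_storage import SQLAlchemyStorage
from dbgpt.storage.metadata.tests.test_sqlalchemy_storage import (
    Base,
    MockModel,
    MockStorageItem,
    MockStorageItemAdapter,
)
from dbgpt.util.serialization.json_serialization import JsonSerializer

storage = SQLAlchemyStorage(
    "sqlite:///:memory:", MockModel, MockStorageItemAdapter(), JsonSerializer(), base=Base
)
Base.metadata.create_all(storage.db_manager.engine)

storage.save(MockStorageItem(MockResourceIdentifier("1"), "hello"))
# save_or_update takes the upsert path: it finds the existing row by identifier,
# copies the public columns onto it, and merges it back.
storage.save_or_update(MockStorageItem(MockResourceIdentifier("1"), "hello again"))
print(storage.load(MockResourceIdentifier("1"), MockStorageItem).data)  # "hello again"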
@@ -1,94 +0,0 @@
import os
import sqlite3
import logging

from sqlalchemy import create_engine, DDL
from sqlalchemy.exc import OperationalError
from sqlalchemy.orm import sessionmaker
from sqlalchemy.ext.declarative import declarative_base

from alembic import command
from alembic.config import Config as AlembicConfig
from urllib.parse import quote
from dbgpt._private.config import Config
from dbgpt.configs.model_config import PILOT_PATH
from urllib.parse import quote_plus as urlquote


logger = logging.getLogger(__name__)
# DB-GPT metadata database config, now supports MySQL and SQLite
CFG = Config()
default_db_path = os.path.join(PILOT_PATH, "meta_data")

os.makedirs(default_db_path, exist_ok=True)

# Meta Info
META_DATA_DATABASE = CFG.LOCAL_DB_NAME
db_name = META_DATA_DATABASE
db_path = default_db_path + f"/{db_name}.db"
connection = sqlite3.connect(db_path)


if CFG.LOCAL_DB_TYPE == "mysql":
    engine_temp = create_engine(
        f"mysql+pymysql://{quote(CFG.LOCAL_DB_USER)}:{urlquote(CFG.LOCAL_DB_PASSWORD)}@{CFG.LOCAL_DB_HOST}:{str(CFG.LOCAL_DB_PORT)}"
    )
    # Check for and automatically create the MySQL database
    try:
        # try to connect
        with engine_temp.connect() as conn:
            # TODO We should consider that the production environment does not have permission to execute the DDL
            conn.execute(DDL(f"CREATE DATABASE IF NOT EXISTS {db_name}"))
        print(f"Already connect '{db_name}'")

    except OperationalError as e:
        # if connect failed, create dbgpt database
        logger.error(f"{db_name} not connect success!")

    engine = create_engine(
        f"mysql+pymysql://{quote(CFG.LOCAL_DB_USER)}:{urlquote(CFG.LOCAL_DB_PASSWORD)}@{CFG.LOCAL_DB_HOST}:{str(CFG.LOCAL_DB_PORT)}/{db_name}"
    )
else:
    engine = create_engine(f"sqlite:///{db_path}")


Session = sessionmaker(autocommit=False, autoflush=False, bind=engine)
session = Session()

Base = declarative_base()

# Base.metadata.create_all()

alembic_ini_path = default_db_path + "/alembic.ini"
alembic_cfg = AlembicConfig(alembic_ini_path)

alembic_cfg.set_main_option("sqlalchemy.url", str(engine.url))

os.makedirs(default_db_path + "/alembic", exist_ok=True)
os.makedirs(default_db_path + "/alembic/versions", exist_ok=True)

alembic_cfg.set_main_option("script_location", default_db_path + "/alembic")

alembic_cfg.attributes["target_metadata"] = Base.metadata
alembic_cfg.attributes["session"] = session


def ddl_init_and_upgrade(disable_alembic_upgrade: bool):
    """Initialize and upgrade database metadata

    Args:
        disable_alembic_upgrade (bool): Whether to disable alembic initialization and upgrade of database metadata
    """
    if disable_alembic_upgrade:
        logger.info(
            "disable_alembic_upgrade is true, not to initialize and upgrade database metadata with alembic"
        )
        return

    with engine.connect() as connection:
        alembic_cfg.attributes["connection"] = connection
        heads = command.heads(alembic_cfg)
        print("heads:" + str(heads))

        command.revision(alembic_cfg, "dbgpt ddl upate", True)
        command.upgrade(alembic_cfg, "head")
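The module-level engine, session, and alembic bootstrapping removed above is superseded by the DatabaseManager API added in this commit; below is a hedged sketch of the new-style equivalent (the sqlite path mirrors the docstring example in db_manager.py, and alembic-based migrations are not reproduced here).

# Sketch only (not part of this commit); the path is illustrative.
from dbgpt.storage.metadata import db

db.init_default_db("/tmp/dbgpt.db")  # pooled SQLite engine + scoped session, replacing the old module globals
db.create_all()                      # creates mapped tables; alembic upgrades are not covered by this sketch

with db.session() as session:        # replaces the old module-level `session = Session()` usage
    print(session.get_bind().url)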
dbgpt/storage/metadata/tests/__init__.py (new file, 0 lines)

dbgpt/storage/metadata/tests/test_db_manager.py (new file, 129 lines)
@@ -0,0 +1,129 @@
from __future__ import annotations
import pytest
from typing import Type
from dbgpt.storage.metadata.db_manager import (
    DatabaseManager,
    PaginationResult,
    create_model,
    BaseModel,
)
from sqlalchemy import Column, Integer, String


@pytest.fixture
def db():
    db = DatabaseManager()
    db.init_db("sqlite:///:memory:")
    return db


@pytest.fixture
def Model(db):
    return create_model(db)


def test_database_initialization(db: DatabaseManager, Model: Type[BaseModel]):
    assert db.engine is not None
    assert db.session is not None

    with db.session() as session:
        assert session is not None


def test_model_creation(db: DatabaseManager, Model: Type[BaseModel]):
    assert db.metadata.tables == {}

    class User(Model):
        __tablename__ = "user"
        id = Column(Integer, primary_key=True)
        name = Column(String(50))

    db.create_all()
    assert list(db.metadata.tables.keys())[0] == "user"


def test_crud_operations(db: DatabaseManager, Model: Type[BaseModel]):
    class User(Model):
        __tablename__ = "user"
        id = Column(Integer, primary_key=True)
        name = Column(String(50))

    db.create_all()

    # Create
    with db.session() as session:
        user = User.create(name="John Doe")
        session.add(user)
        session.commit()

    # Read
    with db.session() as session:
        user = session.query(User).filter_by(name="John Doe").first()
        assert user is not None

    # Update
    with db.session() as session:
        user = session.query(User).filter_by(name="John Doe").first()
        user.update(name="Jane Doe")

    # Delete
    with db.session() as session:
        user = session.query(User).filter_by(name="Jane Doe").first()
        user.delete()


def test_crud_mixins(db: DatabaseManager, Model: Type[BaseModel]):
    class User(Model):
        __tablename__ = "user"
        id = Column(Integer, primary_key=True)
        name = Column(String(50))

    db.create_all()

    # Create
    user = User.create(name="John Doe")
    assert User.get(user.id) is not None
    users = User.all()
    assert len(users) == 1

    # Update
    user.update(name="Bob Doe")
    assert User.get(user.id).name == "Bob Doe"

    user = User.get(user.id)
    user.delete()
    assert User.get(user.id) is None


def test_pagination_query(db: DatabaseManager, Model: Type[BaseModel]):
    class User(Model):
        __tablename__ = "user"
        id = Column(Integer, primary_key=True)
        name = Column(String(50))

    db.create_all()
    # Add some data
    with db.session() as session:
        for i in range(30):
            user = User(name=f"User {i}")
            session.add(user)
        session.commit()

    users_page_1 = User.query.paginate_query(page=1, per_page=10)
    assert len(users_page_1.items) == 10
    assert users_page_1.total_pages == 3


def test_invalid_pagination(db: DatabaseManager, Model: Type[BaseModel]):
    class User(Model):
        __tablename__ = "user"
        id = Column(Integer, primary_key=True)
        name = Column(String(50))

    db.create_all()

    with pytest.raises(ValueError):
        User.query.paginate_query(page=0, per_page=10)
    with pytest.raises(ValueError):
        User.query.paginate_query(page=1, per_page=-1)
dbgpt/storage/metadata/tests/test_sqlalchemy_storage.py (new file, 173 lines)
@@ -0,0 +1,173 @@
from typing import Dict, Type
from sqlalchemy.orm import declarative_base, Session
from sqlalchemy import Column, Integer, String

import pytest

from dbgpt.core.interface.storage import (
    StorageItem,
    ResourceIdentifier,
    StorageItemAdapter,
    QuerySpec,
)
from dbgpt.storage.metadata.db_storage import SQLAlchemyStorage

from dbgpt.core.interface.tests.test_storage import MockResourceIdentifier
from dbgpt.util.serialization.json_serialization import JsonSerializer


Base = declarative_base()


class MockModel(Base):
    """The SQLAlchemy model for the mock data."""

    __tablename__ = "mock_data"
    id = Column(Integer, primary_key=True)
    data = Column(String)


class MockStorageItem(StorageItem):
    """The mock storage item."""

    def merge(self, other: "StorageItem") -> None:
        if not isinstance(other, MockStorageItem):
            raise ValueError("other must be a MockStorageItem")
        self.data = other.data

    def __init__(self, identifier: ResourceIdentifier, data: str):
        self._identifier = identifier
        self.data = data

    @property
    def identifier(self) -> ResourceIdentifier:
        return self._identifier

    def to_dict(self) -> Dict:
        return {"identifier": self._identifier, "data": self.data}

    def serialize(self) -> bytes:
        return str(self.data).encode()


class MockStorageItemAdapter(StorageItemAdapter[MockStorageItem, MockModel]):
    """The adapter for the mock storage item."""

    def to_storage_format(self, item: MockStorageItem) -> MockModel:
        return MockModel(id=int(item.identifier.str_identifier), data=item.data)

    def from_storage_format(self, model: MockModel) -> MockStorageItem:
        return MockStorageItem(MockResourceIdentifier(str(model.id)), model.data)

    def get_query_for_identifier(
        self,
        storage_format: Type[MockModel],
        resource_id: ResourceIdentifier,
        **kwargs,
    ):
        session: Session = kwargs.get("session")
        if session is None:
            raise ValueError("session is required for this adapter")
        return session.query(storage_format).filter(
            storage_format.id == int(resource_id.str_identifier)
        )


@pytest.fixture
def serializer():
    return JsonSerializer()


@pytest.fixture
def db_url():
    """Use in-memory SQLite database for testing"""
    return "sqlite:///:memory:"


@pytest.fixture
def sqlalchemy_storage(db_url, serializer):
    adapter = MockStorageItemAdapter()
    storage = SQLAlchemyStorage(db_url, MockModel, adapter, serializer, base=Base)
    Base.metadata.create_all(storage.db_manager.engine)
    return storage


def test_save_and_load(sqlalchemy_storage):
    item = MockStorageItem(MockResourceIdentifier("1"), "test_data")

    sqlalchemy_storage.save(item)

    loaded_item = sqlalchemy_storage.load(MockResourceIdentifier("1"), MockStorageItem)
    assert loaded_item.data == "test_data"


def test_delete(sqlalchemy_storage):
    resource_id = MockResourceIdentifier("1")

    sqlalchemy_storage.delete(resource_id)
    # Make sure the item is deleted
    assert sqlalchemy_storage.load(resource_id, MockStorageItem) is None
def test_query_with_various_conditions(sqlalchemy_storage):
    # Add multiple items for testing
    for i in range(5):
        item = MockStorageItem(MockResourceIdentifier(str(i)), f"test_data_{i}")
        sqlalchemy_storage.save(item)

    # Test query with a single condition
    query_spec = QuerySpec(conditions={"data": "test_data_2"})
    results = sqlalchemy_storage.query(query_spec, MockStorageItem)
    assert len(results) == 1
    assert results[0].data == "test_data_2"

    # Test a condition that matches nothing
    query_spec = QuerySpec(conditions={"data": "nonexistent"})
    results = sqlalchemy_storage.query(query_spec, MockStorageItem)
    assert len(results) == 0

    # Test query with multiple conditions
    query_spec = QuerySpec(conditions={"data": "test_data_2", "id": "2"})
    results = sqlalchemy_storage.query(query_spec, MockStorageItem)
    assert len(results) == 1


def test_query_nonexistent_item(sqlalchemy_storage):
    query_spec = QuerySpec(conditions={"data": "nonexistent"})
    results = sqlalchemy_storage.query(query_spec, MockStorageItem)
    assert len(results) == 0


def test_count_items(sqlalchemy_storage):
    for i in range(5):
        item = MockStorageItem(MockResourceIdentifier(str(i)), f"test_data_{i}")
        sqlalchemy_storage.save(item)

    # Test count without conditions
    query_spec = QuerySpec(conditions={})
    total_count = sqlalchemy_storage.count(query_spec, MockStorageItem)
    assert total_count == 5

    # Test count with conditions
    query_spec = QuerySpec(conditions={"data": "test_data_2"})
    total_count = sqlalchemy_storage.count(query_spec, MockStorageItem)
    assert total_count == 1


def test_paginate_query(sqlalchemy_storage):
    for i in range(10):
        item = MockStorageItem(MockResourceIdentifier(str(i)), f"test_data_{i}")
        sqlalchemy_storage.save(item)

    page_size = 3
    page_number = 2

    query_spec = QuerySpec(conditions={})
    page_result = sqlalchemy_storage.paginate_query(
        page_number, page_size, MockStorageItem, query_spec
    )

    assert len(page_result.items) == page_size
    assert page_result.page == page_number
    assert page_result.total_pages == 4
    assert page_result.total_count == 10