DB-GPT/dbgpt/storage/metadata/db_manager.py
lipengfei 88e3d12bd3
feat: add document structure into GraphRAG (#2033)
Co-authored-by: Appointat <kuda.czk@antgroup.com>
Co-authored-by: tpoisonooo <khj.application@aliyun.com>
Co-authored-by: vritser <vritser@163.com>
2024-10-18 22:03:08 +08:00

544 lines
18 KiB
Python

"""The database manager."""
from __future__ import annotations
import logging
from contextlib import contextmanager
from typing import ClassVar, Dict, Generic, Iterator, Optional, Type, TypeVar, Union
from sqlalchemy import URL, Engine, MetaData, create_engine, inspect, orm
from sqlalchemy.orm import (
DeclarativeMeta,
Session,
declarative_base,
scoped_session,
sessionmaker,
)
from sqlalchemy.pool import QueuePool
from dbgpt.util.pagination_utils import PaginationResult
from dbgpt.util.string_utils import _to_str
# Module-level logger named after this module.
logger = logging.getLogger(__name__)
# Type variable used by the generic model classes below; bound (by forward
# reference) to the ``BaseModel`` class declared later in this module.
T = TypeVar("T", bound="BaseModel")
# class _QueryObject:
# """The query object."""
#
# def __get__(self, obj: Union[_Model, None], model_cls: type[_Model]):
# return model_cls.query_class(
# model_cls, session=model_cls.__db_manager__._session()
# )
#
class BaseQuery(orm.Query):
    """Base query class.

    Extends SQLAlchemy's ``orm.Query`` with a pagination helper.
    """

    def paginate_query(self, page: int = 1, per_page: int = 20) -> PaginationResult:
        """Paginate the query.

        Example:
            .. code-block:: python

                from dbgpt.storage.metadata import db, Model


                class User(Model):
                    __tablename__ = "user"
                    id = Column(Integer, primary_key=True)
                    name = Column(String(50))
                    fullname = Column(String(50))


                with db.session() as session:
                    pagination = session.query(User).paginate_query(
                        page=1, per_page=10
                    )
                    print(pagination)

        Args:
            page (int, optional): The 1-based page number. Defaults to 1.
            per_page (int, optional): The number of items per page. Defaults
                to 20.

        Returns:
            PaginationResult: The pagination result.

        Raises:
            ValueError: If ``page`` or ``per_page`` is less than 1.
        """
        if page < 1:
            raise ValueError("Page must be greater than 0")
        # Reject 0 as well as negatives: a zero page size would otherwise cause
        # a ZeroDivisionError when total_pages is computed below.
        if per_page < 1:
            raise ValueError("Per page must be greater than 0")
        items = self.limit(per_page).offset((page - 1) * per_page).all()
        # Clear any ORDER BY before counting the rows of the un-paginated query.
        total = self.order_by(None).count()
        # Ceiling division; evaluates to 0 when there are no rows.
        total_pages = (total - 1) // per_page + 1
        return PaginationResult(
            items=items,
            total_count=total,
            total_pages=total_pages,
            page=page,
            page_size=per_page,
        )
class _Model:
    """Plain base class handed to ``declarative_base`` as the model root.

    Carries the query class used by models and a readable ``__repr__``.
    """

    # Owning manager; assigned by ``DatabaseManager`` when it builds or adopts
    # a declarative base.
    __db_manager__: ClassVar[DatabaseManager]
    query_class = BaseQuery
    # query: Optional[BaseQuery] = _QueryObject()

    def __repr__(self):
        """Return ``<ClassName pk>`` or ``<ClassName (transient id)>``."""
        identity = inspect(self).identity
        if identity is not None:
            # Persistent/detached instance: show every primary-key component.
            pk = ", ".join(_to_str(part) for part in identity)
        else:
            # No database identity yet: fall back to the Python object id.
            pk = f"(transient {id(self)})"
        return f"<{type(self).__name__} {pk}>"
class DatabaseManager:
    """The database manager.

    Holds the SQLAlchemy engine, the session factory and the declarative base
    used to declare models. An instance is unusable for sessions until
    :meth:`init_db` (or :meth:`init_default_db`) has been called.

    Examples:
        .. code-block:: python

            from sqlalchemy import Column, Integer, String
            from dbgpt.storage.metadata import DatabaseManager, create_model

            db = DatabaseManager()
            # Use sqlite with memory storage.
            url = "sqlite:///:memory:"
            engine_args = {
                "pool_size": 10,
                "max_overflow": 20,
                "pool_timeout": 30,
                "pool_recycle": 3600,
                "pool_pre_ping": True,
            }
            db.init_db(url, engine_args=engine_args)

            Model = create_model(db)


            class User(Model):
                __tablename__ = "user"
                id = Column(Integer, primary_key=True)
                name = Column(String(50))
                fullname = Column(String(50))


            db.create_all()

            with db.session() as session:
                session.add(User(name="test", fullname="test"))
                # db will commit the session automatically by default.
                # session.commit()
                assert (
                    session.query(User).filter(User.name == "test").first().name
                    == "test"
                )

            # More usage:

            with db.session() as session:
                session.add(User(name="test1", fullname="test1"))
                session.add(User(name="test2", fullname="test1"))
                users = session.query(User).all()
                print(users)
                user = users[0]
                user.name = "test1_1111"
                session.merge(user)

                user2 = users[1]
                # Update user2 by merge
                user2.name = "test2_1111"
                session.merge(user2)
                session.commit()
                # Delete user2
                session.delete(user2)
    """

    # Default query class; may be replaced per instance by ``init_db``.
    Query = BaseQuery

    def __init__(self):
        """Create a DatabaseManager.

        The manager starts uninitialized: it has a declarative base but no
        engine and no session factory.
        """
        self._db_url = None
        self._base: DeclarativeMeta = self._make_declarative_base(_Model)
        self._engine: Optional[Engine] = None
        # Holds the ``sessionmaker`` created by ``init_db`` (the annotation
        # predates the switch away from ``scoped_session``).
        self._session: Optional[scoped_session] = None

    @property
    def Model(self) -> _Model:
        """Get the declarative base (a class, to be subclassed by models)."""
        return self._base  # type: ignore

    @property
    def metadata(self) -> MetaData:
        """Get the metadata."""
        return self.Model.metadata  # type: ignore

    @property
    def engine(self):
        """Get the engine, or None if the manager is not initialized."""
        return self._engine

    @property
    def is_initialized(self) -> bool:
        """Whether the database manager is initialized."""
        return self._engine is not None and self._session is not None

    @contextmanager
    def session(self, commit: Optional[bool] = True) -> Iterator[Session]:
        """Get the session with context manager.

        This context manager handles the lifecycle of a SQLAlchemy session.
        It automatically commits or rolls back transactions based on
        the execution and handles session closure.

        The `commit` parameter controls whether the session should commit
        changes at the end of the block. This is useful for separating
        read and write operations.

        Examples:
            .. code-block:: python

                # For write operations (insert, update, delete):
                with db.session() as session:
                    user = User(name="John Doe")
                    session.add(user)
                    # session.commit() is called automatically

                # For read-only operations:
                with db.session(commit=False) as session:
                    user = session.query(User).filter_by(name="John Doe").first()
                    # session.commit() is NOT called, as it's unnecessary for read
                    # operations

        Args:
            commit (Optional[bool], optional): Whether to commit the session.
                If True (default), the session will commit changes at the end
                of the block. Use False for read-only operations or when manual
                control over commit is needed. Defaults to True.

        Yields:
            Session: The SQLAlchemy session object.

        Raises:
            RuntimeError: Raised if the database manager is not initialized.
            Exception: Propagates any exception that occurred within the block.

        Important Notes:
            - DetachedInstanceError: This error occurs when trying to access or
              modify an instance that has been detached from its session.
              DetachedInstanceError can occur in scenarios where the session is
              closed, and further interaction with the ORM object is attempted,
              especially when accessing lazy-loaded attributes. To avoid this:
                a. Ensure required attributes are loaded before session closure.
                b. Avoid closing the session before all necessary interactions
                   with the ORM object are complete.
                c. Re-bind the instance to a new session if further interaction
                   is required after the session is closed.
        """
        if not self.is_initialized:
            raise RuntimeError("The database manager is not initialized.")
        session = self._session()  # type: ignore
        try:
            yield session
            if commit:
                session.commit()
        except Exception:
            # Covers errors raised inside the ``with`` body and by commit().
            session.rollback()
            raise
        finally:
            session.close()

    def _make_declarative_base(
        self, model: Union[Type[DeclarativeMeta], Type[_Model]]
    ) -> DeclarativeMeta:
        """Make the declarative base.

        Args:
            model (DeclarativeMeta): The base class. A plain class is wrapped
                with ``declarative_base``; an existing declarative base is used
                as is.

        Returns:
            DeclarativeMeta: The declarative base, tagged with this manager and
                a query class.
        """
        if not isinstance(model, DeclarativeMeta):
            model = declarative_base(cls=model, name="Model")
        if not getattr(model, "query_class", None):
            model.query_class = self.Query  # type: ignore
        # model.query = _QueryObject()
        model.__db_manager__ = self  # type: ignore
        return model  # type: ignore

    def init_db(
        self,
        db_url: Union[str, URL],
        engine_args: Optional[Dict] = None,
        base: Optional[DeclarativeMeta] = None,
        query_class=BaseQuery,
        override_query_class: Optional[bool] = False,
        session_options: Optional[Dict] = None,
    ):
        """Initialize the database manager.

        Creates the engine and the session factory. The ``session_options``
        dict passed by the caller is not modified.

        Args:
            db_url (Union[str, URL]): The database url.
            engine_args (Optional[Dict], optional): The engine arguments. Defaults to
                None.
            base (Optional[DeclarativeMeta]): The base class. Defaults to None.
            query_class (BaseQuery, optional): The query class. Defaults to BaseQuery.
            override_query_class (Optional[bool], optional): Whether to override the
                query class. Defaults to False.
            session_options (Optional[Dict], optional): The session options. Defaults
                to None.
        """
        # Work on a private copy: the ``setdefault`` calls below must not leak
        # ``class_`` / ``query_cls`` into a dict the caller may reuse for
        # another manager.
        session_options = dict(session_options or {})
        self._db_url = db_url
        if query_class is not None:
            self.Query = query_class
        if base is not None:
            self._base = base
            # if not hasattr(base, "query") or override_query_class:
            #     base.query = _QueryObject()
            if not getattr(base, "query_class", None) or override_query_class:
                base.query_class = self.Query
            if not hasattr(base, "__db_manager__") or override_query_class:
                base.__db_manager__ = self
        self._engine = create_engine(db_url, **(engine_args or {}))

        # Caller-provided values win; these are only fallbacks.
        session_options.setdefault("class_", Session)
        session_options.setdefault("query_cls", self.Query)
        session_factory = sessionmaker(bind=self._engine, **session_options)
        # self._session = scoped_session(session_factory)
        self._session = session_factory  # type: ignore
        self._base.metadata.bind = self._engine  # type: ignore

    def init_default_db(
        self,
        sqlite_path: str,
        engine_args: Optional[Dict] = None,
        base: Optional[DeclarativeMeta] = None,
    ):
        """Initialize the database manager with default config.

        Examples:
            >>> db.init_default_db(sqlite_path)
            >>> with db.session() as session:
            ...     session.query(...)
            ...

        Args:
            sqlite_path (str): The sqlite path.
            engine_args (Optional[Dict], optional): The engine arguments.
                Defaults to None, if None (or empty), we will use connection pool.
            base (Optional[DeclarativeMeta]): The base class. Defaults to None.
        """
        if not engine_args:
            engine_args = {
                # Pool class
                "poolclass": QueuePool,
                # The number of connections to keep open inside the connection pool.
                "pool_size": 10,
                # The maximum overflow size of the pool when the number of connections
                # be used in the pool is exceeded(pool_size).
                "max_overflow": 20,
                # The number of seconds to wait before giving up on getting a connection
                # from the pool.
                "pool_timeout": 30,
                # Recycle the connection if it has been idle for this many seconds.
                "pool_recycle": 3600,
                # Enable the connection pool "pre-ping" feature that tests connections
                # for liveness upon each checkout.
                "pool_pre_ping": True,
            }
        self.init_db(f"sqlite:///{sqlite_path}", engine_args, base)

    def create_all(self):
        """Create all tables registered on the declarative base's metadata."""
        self.Model.metadata.create_all(self._engine)

    @staticmethod
    def build_from(
        db_url_or_db: Union[str, URL, DatabaseManager],
        engine_args: Optional[Dict] = None,
        base: Optional[DeclarativeMeta] = None,
        query_class=BaseQuery,
        override_query_class: Optional[bool] = False,
    ) -> DatabaseManager:
        """Build the database manager from the db_url_or_db.

        Examples:
            Build from the database url.

            .. code-block:: python

                from dbgpt.storage.metadata import DatabaseManager
                from sqlalchemy import Column, Integer, String

                db = DatabaseManager.build_from("sqlite:///:memory:")


                class User(db.Model):
                    __tablename__ = "user"
                    id = Column(Integer, primary_key=True)
                    name = Column(String(50))
                    fullname = Column(String(50))


                db.create_all()
                with db.session() as session:
                    session.add(User(name="test", fullname="test"))
                    session.commit()
                    print(session.query(User).filter(User.name == "test").all())

        Args:
            db_url_or_db (Union[str, URL, DatabaseManager]): The database url or the
                database manager.
            engine_args (Optional[Dict], optional): The engine arguments. Defaults to
                None.
            base (Optional[DeclarativeMeta]): The base class. Defaults to None.
            query_class (BaseQuery, optional): The query class. Defaults to BaseQuery.
            override_query_class (Optional[bool], optional): Whether to override the
                query class. Defaults to False.

        Returns:
            DatabaseManager: The database manager. An existing manager passed in
                is returned unchanged (the other arguments are ignored).

        Raises:
            ValueError: If ``db_url_or_db`` is neither a url nor a DatabaseManager.
        """
        if isinstance(db_url_or_db, (str, URL)):
            db_manager = DatabaseManager()
            db_manager.init_db(
                db_url_or_db, engine_args, base, query_class, override_query_class
            )
            return db_manager
        elif isinstance(db_url_or_db, DatabaseManager):
            return db_url_or_db
        else:
            raise ValueError(
                f"db_url_or_db should be either url or a DatabaseManager, got "
                f"{type(db_url_or_db)}"
            )
# Module-level DatabaseManager instance. It is created uninitialized here and
# only becomes usable for sessions after init_db / init_default_db is called.
db = DatabaseManager()
"""The global database manager.
Examples:
>>> from dbgpt.storage.metadata import db
>>> sqlite_path = "/tmp/dbgpt.db"
>>> db.init_default_db(sqlite_path)
>>> with db.session() as session:
... session.query(...)
...
>>> from dbgpt.storage.metadata import db, Model
>>> from urllib.parse import quote_plus as urlquote, quote
>>> db_name = "dbgpt"
>>> db_host = "localhost"
>>> db_port = 3306
>>> user = "root"
>>> password = "123456"
>>> url = (
... f"mysql+pymysql://{quote(user)}:{urlquote(password)}@{db_host}"
... f":{str(db_port)}/{db_name}"
... )
>>> engine_args = {
... "pool_size": 10,
... "max_overflow": 20,
... "pool_timeout": 30,
... "pool_recycle": 3600,
... "pool_pre_ping": True,
... }
>>> db.init_db(url, engine_args=engine_args)
>>> class User(Model):
... __tablename__ = "user"
... id = Column(Integer, primary_key=True)
... name = Column(String(50))
... fullname = Column(String(50))
...
>>> with db.session() as session:
... session.add(User(name="test", fullname="test"))
... session.commit()
...
"""
class BaseCRUDMixin(Generic[T]):
    """Generic mixin giving model classes access to their database manager."""

    # Marker attribute; presumably read by SQLAlchemy declarative to skip
    # table mapping for this class -- confirm against the SQLAlchemy docs.
    __abstract__ = True

    @classmethod
    def db(cls) -> DatabaseManager:
        """Return the manager stored on the class as ``__db_manager__``."""
        manager = cls.__db_manager__  # type: ignore
        return manager
class BaseModel(BaseCRUDMixin[T], _Model, Generic[T]):
    """Base model type combining ``_Model`` with the CRUD mixin."""

    # Whether the model is abstract.
    __abstract__ = True
def create_model(db_manager: DatabaseManager) -> Type[BaseModel[T]]:
    """Build an abstract model base class tied to ``db_manager``.

    The returned class inherits from the manager's declarative base and from
    a mixin whose ``db()`` classmethod returns the manager.

    Args:
        db_manager (DatabaseManager): The manager the new model base uses.

    Returns:
        Type[BaseModel[T]]: The abstract base class to subclass for models.
    """

    class CRUDMixin(BaseCRUDMixin[T], Generic[T]):  # type: ignore
        """Mixin that adds convenience methods for CRUD."""

        # Starts as the manager given to create_model; replaceable via set_db.
        _db_manager: DatabaseManager = db_manager

        @classmethod
        def db(cls) -> DatabaseManager:
            """Get the database manager."""
            return cls._db_manager

        @classmethod
        def set_db(cls, db_manager: DatabaseManager):
            """Replace the database manager stored on the class."""
            # TODO: It is hard to replace to user DB Connection
            cls._db_manager = db_manager

    declarative_root = db_manager.Model

    class _NewModel(CRUDMixin[T], declarative_root, Generic[T]):  # type: ignore
        """Base model class that includes CRUD convenience methods."""

        __abstract__ = True

    return _NewModel
# Default model base class, built from the module-level ``db`` manager.
Model: Type = create_model(db)
def initialize_db(
    db_url: Union[str, URL],
    db_name: str,
    engine_args: Optional[Dict] = None,
    base: Optional[DeclarativeMeta] = None,
    try_to_create_db: Optional[bool] = False,
    session_options: Optional[Dict] = None,
) -> DatabaseManager:
    """Initialize the module-level database manager and return it.

    Args:
        db_url (Union[str, URL]): The database url.
        db_name (str): The database name; only used in the failure log message.
        engine_args (Optional[Dict], optional): The engine arguments. Defaults to
            None.
        base (Optional[DeclarativeMeta]): The base class. Defaults to None.
        try_to_create_db (Optional[bool], optional): Whether to attempt table
            creation after initialization. A failure is logged, not raised.
            Defaults to False.
        session_options (Optional[Dict], optional): The session options. Defaults
            to None.

    Returns:
        DatabaseManager: The module-level ``db`` manager.
    """
    db.init_db(db_url, engine_args, base, session_options=session_options)
    if not try_to_create_db:
        return db
    # Best effort: a creation failure is logged and the manager is still returned.
    try:
        db.create_all()
    except Exception as e:
        logger.error(f"Failed to create database {db_name}: {e}")
    return db