mirror of
https://github.com/csunny/DB-GPT.git
synced 2025-09-05 19:11:52 +00:00
chore: Add pylint for storage (#1298)
This commit is contained in:
@@ -1,8 +1,9 @@
|
||||
"""The database manager."""
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from contextlib import contextmanager
|
||||
from typing import ClassVar, Dict, Generic, Optional, Type, TypeVar, Union
|
||||
from typing import ClassVar, Dict, Generic, Iterator, Optional, Type, TypeVar, Union
|
||||
|
||||
from sqlalchemy import URL, Engine, MetaData, create_engine, inspect, orm
|
||||
from sqlalchemy.orm import (
|
||||
@@ -21,19 +22,20 @@ logger = logging.getLogger(__name__)
|
||||
T = TypeVar("T", bound="BaseModel")
|
||||
|
||||
|
||||
class _QueryObject:
|
||||
"""The query object."""
|
||||
|
||||
def __get__(self, obj: Union[_Model, None], model_cls: type[_Model]):
|
||||
return model_cls.query_class(
|
||||
model_cls, session=model_cls.__db_manager__._session()
|
||||
)
|
||||
# class _QueryObject:
|
||||
# """The query object."""
|
||||
#
|
||||
# def __get__(self, obj: Union[_Model, None], model_cls: type[_Model]):
|
||||
# return model_cls.query_class(
|
||||
# model_cls, session=model_cls.__db_manager__._session()
|
||||
# )
|
||||
#
|
||||
|
||||
|
||||
class BaseQuery(orm.Query):
|
||||
def paginate_query(
|
||||
self, page: Optional[int] = 1, per_page: Optional[int] = 20
|
||||
) -> PaginationResult:
|
||||
"""Base query class."""
|
||||
|
||||
def paginate_query(self, page: int = 1, per_page: int = 20) -> PaginationResult:
|
||||
"""Paginate the query.
|
||||
|
||||
Example:
|
||||
@@ -56,10 +58,10 @@ class BaseQuery(orm.Query):
|
||||
)
|
||||
print(pagination)
|
||||
|
||||
|
||||
Args:
|
||||
page (Optional[int], optional): The page number. Defaults to 1.
|
||||
per_page (Optional[int], optional): The number of items per page. Defaults to 20.
|
||||
per_page (Optional[int], optional): The number of items per page. Defaults
|
||||
to 20.
|
||||
Returns:
|
||||
PaginationResult: The pagination result.
|
||||
"""
|
||||
@@ -100,7 +102,6 @@ class DatabaseManager:
|
||||
"""The database manager.
|
||||
|
||||
Examples:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
from urllib.parse import quote_plus as urlquote, quote
|
||||
@@ -161,6 +162,7 @@ class DatabaseManager:
|
||||
Query = BaseQuery
|
||||
|
||||
def __init__(self):
|
||||
"""Create a DatabaseManager."""
|
||||
self._db_url = None
|
||||
self._base: DeclarativeMeta = self._make_declarative_base(_Model)
|
||||
self._engine: Optional[Engine] = None
|
||||
@@ -169,12 +171,12 @@ class DatabaseManager:
|
||||
@property
|
||||
def Model(self) -> _Model:
|
||||
"""Get the declarative base."""
|
||||
return self._base
|
||||
return self._base # type: ignore
|
||||
|
||||
@property
|
||||
def metadata(self) -> MetaData:
|
||||
"""Get the metadata."""
|
||||
return self.Model.metadata
|
||||
return self.Model.metadata # type: ignore
|
||||
|
||||
@property
|
||||
def engine(self):
|
||||
@@ -183,11 +185,11 @@ class DatabaseManager:
|
||||
|
||||
@property
|
||||
def is_initialized(self) -> bool:
|
||||
"""Whether the database manager is initialized.""" ""
|
||||
"""Whether the database manager is initialized."""
|
||||
return self._engine is not None and self._session is not None
|
||||
|
||||
@contextmanager
|
||||
def session(self, commit: Optional[bool] = True) -> Session:
|
||||
def session(self, commit: Optional[bool] = True) -> Iterator[Session]:
|
||||
"""Get the session with context manager.
|
||||
|
||||
This context manager handles the lifecycle of a SQLAlchemy session.
|
||||
@@ -199,7 +201,6 @@ class DatabaseManager:
|
||||
read and write operations.
|
||||
|
||||
Examples:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
# For write operations (insert, update, delete):
|
||||
@@ -211,7 +212,8 @@ class DatabaseManager:
|
||||
# For read-only operations:
|
||||
with db.session(commit=False) as session:
|
||||
user = session.query(User).filter_by(name="John Doe").first()
|
||||
# session.commit() is NOT called, as it's unnecessary for read operations
|
||||
# session.commit() is NOT called, as it's unnecessary for read
|
||||
# operations
|
||||
|
||||
Args:
|
||||
commit (Optional[bool], optional): Whether to commit the session.
|
||||
@@ -237,16 +239,15 @@ class DatabaseManager:
|
||||
with the ORM object are complete.
|
||||
c. Re-bind the instance to a new session if further interaction
|
||||
is required after the session is closed.
|
||||
|
||||
"""
|
||||
if not self.is_initialized:
|
||||
raise RuntimeError("The database manager is not initialized.")
|
||||
session = self._session()
|
||||
session = self._session() # type: ignore
|
||||
try:
|
||||
yield session
|
||||
if commit:
|
||||
session.commit()
|
||||
except:
|
||||
except Exception:
|
||||
session.rollback()
|
||||
raise
|
||||
finally:
|
||||
@@ -266,10 +267,10 @@ class DatabaseManager:
|
||||
if not isinstance(model, DeclarativeMeta):
|
||||
model = declarative_base(cls=model, name="Model")
|
||||
if not getattr(model, "query_class", None):
|
||||
model.query_class = self.Query
|
||||
model.query_class = self.Query # type: ignore
|
||||
# model.query = _QueryObject()
|
||||
model.__db_manager__ = self
|
||||
return model
|
||||
model.__db_manager__ = self # type: ignore
|
||||
return model # type: ignore
|
||||
|
||||
def init_db(
|
||||
self,
|
||||
@@ -284,11 +285,14 @@ class DatabaseManager:
|
||||
|
||||
Args:
|
||||
db_url (Union[str, URL]): The database url.
|
||||
engine_args (Optional[Dict], optional): The engine arguments. Defaults to None.
|
||||
engine_args (Optional[Dict], optional): The engine arguments. Defaults to
|
||||
None.
|
||||
base (Optional[DeclarativeMeta]): The base class. Defaults to None.
|
||||
query_class (BaseQuery, optional): The query class. Defaults to BaseQuery.
|
||||
override_query_class (Optional[bool], optional): Whether to override the query class. Defaults to False.
|
||||
session_options (Optional[Dict], optional): The session options. Defaults to None.
|
||||
override_query_class (Optional[bool], optional): Whether to override the
|
||||
query class. Defaults to False.
|
||||
session_options (Optional[Dict], optional): The session options. Defaults
|
||||
to None.
|
||||
"""
|
||||
if session_options is None:
|
||||
session_options = {}
|
||||
@@ -309,8 +313,8 @@ class DatabaseManager:
|
||||
session_options.setdefault("query_cls", self.Query)
|
||||
session_factory = sessionmaker(bind=self._engine, **session_options)
|
||||
# self._session = scoped_session(session_factory)
|
||||
self._session = session_factory
|
||||
self._base.metadata.bind = self._engine
|
||||
self._session = session_factory # type: ignore
|
||||
self._base.metadata.bind = self._engine # type: ignore
|
||||
|
||||
def init_default_db(
|
||||
self,
|
||||
@@ -333,24 +337,28 @@ class DatabaseManager:
|
||||
base (Optional[DeclarativeMeta]): The base class. Defaults to None.
|
||||
"""
|
||||
if not engine_args:
|
||||
engine_args = {}
|
||||
# Pool class
|
||||
engine_args["poolclass"] = QueuePool
|
||||
# The number of connections to keep open inside the connection pool.
|
||||
engine_args["pool_size"] = 10
|
||||
# The maximum overflow size of the pool when the number of connections be used in the pool is exceeded(
|
||||
# pool_size).
|
||||
engine_args["max_overflow"] = 20
|
||||
# The number of seconds to wait before giving up on getting a connection from the pool.
|
||||
engine_args["pool_timeout"] = 30
|
||||
# Recycle the connection if it has been idle for this many seconds.
|
||||
engine_args["pool_recycle"] = 3600
|
||||
# Enable the connection pool “pre-ping” feature that tests connections for liveness upon each checkout.
|
||||
engine_args["pool_pre_ping"] = True
|
||||
engine_args = {
|
||||
# Pool class
|
||||
"poolclass": QueuePool,
|
||||
# The number of connections to keep open inside the connection pool.
|
||||
"pool_size": 10,
|
||||
# The maximum overflow size of the pool when the number of connections
|
||||
# be used in the pool is exceeded(pool_size).
|
||||
"max_overflow": 20,
|
||||
# The number of seconds to wait before giving up on getting a connection
|
||||
# from the pool.
|
||||
"pool_timeout": 30,
|
||||
# Recycle the connection if it has been idle for this many seconds.
|
||||
"pool_recycle": 3600,
|
||||
# Enable the connection pool “pre-ping” feature that tests connections
|
||||
# for liveness upon each checkout.
|
||||
"pool_pre_ping": True,
|
||||
}
|
||||
|
||||
self.init_db(f"sqlite:///{sqlite_path}", engine_args, base)
|
||||
|
||||
def create_all(self):
|
||||
"""Create all tables."""
|
||||
self.Model.metadata.create_all(self._engine)
|
||||
|
||||
@staticmethod
|
||||
@@ -364,9 +372,7 @@ class DatabaseManager:
|
||||
"""Build the database manager from the db_url_or_db.
|
||||
|
||||
Examples:
|
||||
|
||||
Build from the database url.
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
from dbgpt.storage.metadata import DatabaseManager
|
||||
@@ -389,16 +395,19 @@ class DatabaseManager:
|
||||
print(User.query.filter(User.name == "test").all())
|
||||
|
||||
Args:
|
||||
db_url_or_db (Union[str, URL, DatabaseManager]): The database url or the database manager.
|
||||
engine_args (Optional[Dict], optional): The engine arguments. Defaults to None.
|
||||
db_url_or_db (Union[str, URL, DatabaseManager]): The database url or the
|
||||
database manager.
|
||||
engine_args (Optional[Dict], optional): The engine arguments. Defaults to
|
||||
None.
|
||||
base (Optional[DeclarativeMeta]): The base class. Defaults to None.
|
||||
query_class (BaseQuery, optional): The query class. Defaults to BaseQuery.
|
||||
override_query_class (Optional[bool], optional): Whether to override the query class. Defaults to False.
|
||||
override_query_class (Optional[bool], optional): Whether to override the
|
||||
query class. Defaults to False.
|
||||
|
||||
Returns:
|
||||
DatabaseManager: The database manager.
|
||||
"""
|
||||
if isinstance(db_url_or_db, str) or isinstance(db_url_or_db, URL):
|
||||
if isinstance(db_url_or_db, (str, URL)):
|
||||
db_manager = DatabaseManager()
|
||||
db_manager.init_db(
|
||||
db_url_or_db, engine_args, base, query_class, override_query_class
|
||||
@@ -408,7 +417,8 @@ class DatabaseManager:
|
||||
return db_url_or_db
|
||||
else:
|
||||
raise ValueError(
|
||||
f"db_url_or_db should be either url or a DatabaseManager, got {type(db_url_or_db)}"
|
||||
f"db_url_or_db should be either url or a DatabaseManager, got "
|
||||
f"{type(db_url_or_db)}"
|
||||
)
|
||||
|
||||
|
||||
@@ -422,7 +432,6 @@ Examples:
|
||||
>>> with db.session() as session:
|
||||
... session.query(...)
|
||||
...
|
||||
|
||||
>>> from dbgpt.storage.metadata import db, Model
|
||||
>>> from urllib.parse import quote_plus as urlquote, quote
|
||||
>>> db_name = "dbgpt"
|
||||
@@ -430,7 +439,10 @@ Examples:
|
||||
>>> db_port = 3306
|
||||
>>> user = "root"
|
||||
>>> password = "123456"
|
||||
>>> url = f"mysql+pymysql://{quote(user)}:{urlquote(password)}@{db_host}:{str(db_port)}/{db_name}"
|
||||
>>> url = (
|
||||
... f"mysql+pymysql://{quote(user)}:{urlquote(password)}@{db_host}"
|
||||
... f":{str(db_port)}/{db_name}"
|
||||
... )
|
||||
>>> engine_args = {
|
||||
... "pool_size": 10,
|
||||
... "max_overflow": 20,
|
||||
@@ -460,18 +472,21 @@ class BaseCRUDMixin(Generic[T]):
|
||||
@classmethod
|
||||
def db(cls) -> DatabaseManager:
|
||||
"""Get the database manager."""
|
||||
return cls.__db_manager__
|
||||
return cls.__db_manager__ # type: ignore
|
||||
|
||||
|
||||
class BaseModel(BaseCRUDMixin[T], _Model, Generic[T]):
|
||||
"""The base model class that includes CRUD convenience methods."""
|
||||
|
||||
__abstract__ = True
|
||||
"""Whether the model is abstract."""
|
||||
|
||||
|
||||
def create_model(db_manager: DatabaseManager) -> Type[BaseModel[T]]:
|
||||
class CRUDMixin(BaseCRUDMixin[T], Generic[T]):
|
||||
"""Mixin that adds convenience methods for CRUD (create, read, update, delete)"""
|
||||
"""Create a model."""
|
||||
|
||||
class CRUDMixin(BaseCRUDMixin[T], Generic[T]): # type: ignore
|
||||
"""Mixin that adds convenience methods for CRUD."""
|
||||
|
||||
_db_manager: DatabaseManager = db_manager
|
||||
|
||||
@@ -485,7 +500,7 @@ def create_model(db_manager: DatabaseManager) -> Type[BaseModel[T]]:
|
||||
"""Get the database manager."""
|
||||
return cls._db_manager
|
||||
|
||||
class _NewModel(CRUDMixin[T], db_manager.Model, Generic[T]):
|
||||
class _NewModel(CRUDMixin[T], db_manager.Model, Generic[T]): # type: ignore
|
||||
"""Base model class that includes CRUD convenience methods."""
|
||||
|
||||
__abstract__ = True
|
||||
@@ -493,7 +508,7 @@ def create_model(db_manager: DatabaseManager) -> Type[BaseModel[T]]:
|
||||
return _NewModel
|
||||
|
||||
|
||||
Model = create_model(db)
|
||||
Model: Type = create_model(db)
|
||||
|
||||
|
||||
def initialize_db(
|
||||
@@ -511,8 +526,10 @@ def initialize_db(
|
||||
db_name (str): The database name.
|
||||
engine_args (Optional[Dict], optional): The engine arguments. Defaults to None.
|
||||
base (Optional[DeclarativeMeta]): The base class. Defaults to None.
|
||||
try_to_create_db (Optional[bool], optional): Whether to try to create the database. Defaults to False.
|
||||
session_options (Optional[Dict], optional): The session options. Defaults to None.
|
||||
try_to_create_db (Optional[bool], optional): Whether to try to create the
|
||||
database. Defaults to False.
|
||||
session_options (Optional[Dict], optional): The session options. Defaults to
|
||||
None.
|
||||
Returns:
|
||||
DatabaseManager: The database manager.
|
||||
"""
|
||||
|
Reference in New Issue
Block a user