chore: Add pylint for storage (#1298)

This commit is contained in:
Fangyin Cheng
2024-03-15 15:42:46 +08:00
committed by GitHub
parent a207640ff2
commit 8897d6e8fd
50 changed files with 784 additions and 667 deletions

View File

@@ -1,8 +1,9 @@
"""The database manager."""
from __future__ import annotations
import logging
from contextlib import contextmanager
from typing import ClassVar, Dict, Generic, Optional, Type, TypeVar, Union
from typing import ClassVar, Dict, Generic, Iterator, Optional, Type, TypeVar, Union
from sqlalchemy import URL, Engine, MetaData, create_engine, inspect, orm
from sqlalchemy.orm import (
@@ -21,19 +22,20 @@ logger = logging.getLogger(__name__)
T = TypeVar("T", bound="BaseModel")
class _QueryObject:
"""The query object."""
def __get__(self, obj: Union[_Model, None], model_cls: type[_Model]):
return model_cls.query_class(
model_cls, session=model_cls.__db_manager__._session()
)
# class _QueryObject:
# """The query object."""
#
# def __get__(self, obj: Union[_Model, None], model_cls: type[_Model]):
# return model_cls.query_class(
# model_cls, session=model_cls.__db_manager__._session()
# )
#
class BaseQuery(orm.Query):
def paginate_query(
self, page: Optional[int] = 1, per_page: Optional[int] = 20
) -> PaginationResult:
"""Base query class."""
def paginate_query(self, page: int = 1, per_page: int = 20) -> PaginationResult:
"""Paginate the query.
Example:
@@ -56,10 +58,10 @@ class BaseQuery(orm.Query):
)
print(pagination)
Args:
page (Optional[int], optional): The page number. Defaults to 1.
per_page (Optional[int], optional): The number of items per page. Defaults to 20.
per_page (Optional[int], optional): The number of items per page. Defaults
to 20.
Returns:
PaginationResult: The pagination result.
"""
@@ -100,7 +102,6 @@ class DatabaseManager:
"""The database manager.
Examples:
.. code-block:: python
from urllib.parse import quote_plus as urlquote, quote
@@ -161,6 +162,7 @@ class DatabaseManager:
Query = BaseQuery
def __init__(self):
"""Create a DatabaseManager."""
self._db_url = None
self._base: DeclarativeMeta = self._make_declarative_base(_Model)
self._engine: Optional[Engine] = None
@@ -169,12 +171,12 @@ class DatabaseManager:
@property
def Model(self) -> _Model:
"""Get the declarative base."""
return self._base
return self._base # type: ignore
@property
def metadata(self) -> MetaData:
"""Get the metadata."""
return self.Model.metadata
return self.Model.metadata # type: ignore
@property
def engine(self):
@@ -183,11 +185,11 @@ class DatabaseManager:
@property
def is_initialized(self) -> bool:
"""Whether the database manager is initialized.""" ""
"""Whether the database manager is initialized."""
return self._engine is not None and self._session is not None
@contextmanager
def session(self, commit: Optional[bool] = True) -> Session:
def session(self, commit: Optional[bool] = True) -> Iterator[Session]:
"""Get the session with context manager.
This context manager handles the lifecycle of a SQLAlchemy session.
@@ -199,7 +201,6 @@ class DatabaseManager:
read and write operations.
Examples:
.. code-block:: python
# For write operations (insert, update, delete):
@@ -211,7 +212,8 @@ class DatabaseManager:
# For read-only operations:
with db.session(commit=False) as session:
user = session.query(User).filter_by(name="John Doe").first()
# session.commit() is NOT called, as it's unnecessary for read operations
# session.commit() is NOT called, as it's unnecessary for read
# operations
Args:
commit (Optional[bool], optional): Whether to commit the session.
@@ -237,16 +239,15 @@ class DatabaseManager:
with the ORM object are complete.
c. Re-bind the instance to a new session if further interaction
is required after the session is closed.
"""
if not self.is_initialized:
raise RuntimeError("The database manager is not initialized.")
session = self._session()
session = self._session() # type: ignore
try:
yield session
if commit:
session.commit()
except:
except Exception:
session.rollback()
raise
finally:
@@ -266,10 +267,10 @@ class DatabaseManager:
if not isinstance(model, DeclarativeMeta):
model = declarative_base(cls=model, name="Model")
if not getattr(model, "query_class", None):
model.query_class = self.Query
model.query_class = self.Query # type: ignore
# model.query = _QueryObject()
model.__db_manager__ = self
return model
model.__db_manager__ = self # type: ignore
return model # type: ignore
def init_db(
self,
@@ -284,11 +285,14 @@ class DatabaseManager:
Args:
db_url (Union[str, URL]): The database url.
engine_args (Optional[Dict], optional): The engine arguments. Defaults to None.
engine_args (Optional[Dict], optional): The engine arguments. Defaults to
None.
base (Optional[DeclarativeMeta]): The base class. Defaults to None.
query_class (BaseQuery, optional): The query class. Defaults to BaseQuery.
override_query_class (Optional[bool], optional): Whether to override the query class. Defaults to False.
session_options (Optional[Dict], optional): The session options. Defaults to None.
override_query_class (Optional[bool], optional): Whether to override the
query class. Defaults to False.
session_options (Optional[Dict], optional): The session options. Defaults
to None.
"""
if session_options is None:
session_options = {}
@@ -309,8 +313,8 @@ class DatabaseManager:
session_options.setdefault("query_cls", self.Query)
session_factory = sessionmaker(bind=self._engine, **session_options)
# self._session = scoped_session(session_factory)
self._session = session_factory
self._base.metadata.bind = self._engine
self._session = session_factory # type: ignore
self._base.metadata.bind = self._engine # type: ignore
def init_default_db(
self,
@@ -333,24 +337,28 @@ class DatabaseManager:
base (Optional[DeclarativeMeta]): The base class. Defaults to None.
"""
if not engine_args:
engine_args = {}
# Pool class
engine_args["poolclass"] = QueuePool
# The number of connections to keep open inside the connection pool.
engine_args["pool_size"] = 10
# The maximum overflow size of the pool when the number of connections be used in the pool is exceeded(
# pool_size).
engine_args["max_overflow"] = 20
# The number of seconds to wait before giving up on getting a connection from the pool.
engine_args["pool_timeout"] = 30
# Recycle the connection if it has been idle for this many seconds.
engine_args["pool_recycle"] = 3600
# Enable the connection pool “pre-ping” feature that tests connections for liveness upon each checkout.
engine_args["pool_pre_ping"] = True
engine_args = {
# Pool class
"poolclass": QueuePool,
# The number of connections to keep open inside the connection pool.
"pool_size": 10,
# The maximum overflow size of the pool when the number of connections
# be used in the pool is exceeded(pool_size).
"max_overflow": 20,
# The number of seconds to wait before giving up on getting a connection
# from the pool.
"pool_timeout": 30,
# Recycle the connection if it has been idle for this many seconds.
"pool_recycle": 3600,
# Enable the connection poolpre-ping” feature that tests connections
# for liveness upon each checkout.
"pool_pre_ping": True,
}
self.init_db(f"sqlite:///{sqlite_path}", engine_args, base)
def create_all(self):
"""Create all tables."""
self.Model.metadata.create_all(self._engine)
@staticmethod
@@ -364,9 +372,7 @@ class DatabaseManager:
"""Build the database manager from the db_url_or_db.
Examples:
Build from the database url.
.. code-block:: python
from dbgpt.storage.metadata import DatabaseManager
@@ -389,16 +395,19 @@ class DatabaseManager:
print(User.query.filter(User.name == "test").all())
Args:
db_url_or_db (Union[str, URL, DatabaseManager]): The database url or the database manager.
engine_args (Optional[Dict], optional): The engine arguments. Defaults to None.
db_url_or_db (Union[str, URL, DatabaseManager]): The database url or the
database manager.
engine_args (Optional[Dict], optional): The engine arguments. Defaults to
None.
base (Optional[DeclarativeMeta]): The base class. Defaults to None.
query_class (BaseQuery, optional): The query class. Defaults to BaseQuery.
override_query_class (Optional[bool], optional): Whether to override the query class. Defaults to False.
override_query_class (Optional[bool], optional): Whether to override the
query class. Defaults to False.
Returns:
DatabaseManager: The database manager.
"""
if isinstance(db_url_or_db, str) or isinstance(db_url_or_db, URL):
if isinstance(db_url_or_db, (str, URL)):
db_manager = DatabaseManager()
db_manager.init_db(
db_url_or_db, engine_args, base, query_class, override_query_class
@@ -408,7 +417,8 @@ class DatabaseManager:
return db_url_or_db
else:
raise ValueError(
f"db_url_or_db should be either url or a DatabaseManager, got {type(db_url_or_db)}"
f"db_url_or_db should be either url or a DatabaseManager, got "
f"{type(db_url_or_db)}"
)
@@ -422,7 +432,6 @@ Examples:
>>> with db.session() as session:
... session.query(...)
...
>>> from dbgpt.storage.metadata import db, Model
>>> from urllib.parse import quote_plus as urlquote, quote
>>> db_name = "dbgpt"
@@ -430,7 +439,10 @@ Examples:
>>> db_port = 3306
>>> user = "root"
>>> password = "123456"
>>> url = f"mysql+pymysql://{quote(user)}:{urlquote(password)}@{db_host}:{str(db_port)}/{db_name}"
>>> url = (
... f"mysql+pymysql://{quote(user)}:{urlquote(password)}@{db_host}"
... f":{str(db_port)}/{db_name}"
... )
>>> engine_args = {
... "pool_size": 10,
... "max_overflow": 20,
@@ -460,18 +472,21 @@ class BaseCRUDMixin(Generic[T]):
@classmethod
def db(cls) -> DatabaseManager:
"""Get the database manager."""
return cls.__db_manager__
return cls.__db_manager__ # type: ignore
class BaseModel(BaseCRUDMixin[T], _Model, Generic[T]):
"""The base model class that includes CRUD convenience methods."""
__abstract__ = True
"""Whether the model is abstract."""
def create_model(db_manager: DatabaseManager) -> Type[BaseModel[T]]:
class CRUDMixin(BaseCRUDMixin[T], Generic[T]):
"""Mixin that adds convenience methods for CRUD (create, read, update, delete)"""
"""Create a model."""
class CRUDMixin(BaseCRUDMixin[T], Generic[T]): # type: ignore
"""Mixin that adds convenience methods for CRUD."""
_db_manager: DatabaseManager = db_manager
@@ -485,7 +500,7 @@ def create_model(db_manager: DatabaseManager) -> Type[BaseModel[T]]:
"""Get the database manager."""
return cls._db_manager
class _NewModel(CRUDMixin[T], db_manager.Model, Generic[T]):
class _NewModel(CRUDMixin[T], db_manager.Model, Generic[T]): # type: ignore
"""Base model class that includes CRUD convenience methods."""
__abstract__ = True
@@ -493,7 +508,7 @@ def create_model(db_manager: DatabaseManager) -> Type[BaseModel[T]]:
return _NewModel
Model = create_model(db)
Model: Type = create_model(db)
def initialize_db(
@@ -511,8 +526,10 @@ def initialize_db(
db_name (str): The database name.
engine_args (Optional[Dict], optional): The engine arguments. Defaults to None.
base (Optional[DeclarativeMeta]): The base class. Defaults to None.
try_to_create_db (Optional[bool], optional): Whether to try to create the database. Defaults to False.
session_options (Optional[Dict], optional): The session options. Defaults to None.
try_to_create_db (Optional[bool], optional): Whether to try to create the
database. Defaults to False.
session_options (Optional[Dict], optional): The session options. Defaults to
None.
Returns:
DatabaseManager: The database manager.
"""