feat(model): Proxy model support count token (#996)

This commit is contained in:
Fangyin Cheng
2023-12-29 12:01:31 +08:00
committed by GitHub
parent ba0599ebf4
commit 0cdc77abb2
16 changed files with 366 additions and 248 deletions

View File

@@ -1,8 +1,15 @@
from __future__ import annotations
import abc
from contextlib import contextmanager
from typing import TypeVar, Generic, Union, Dict, Optional, Type, Iterator, List
from typing import (
TypeVar,
Generic,
Union,
Dict,
Optional,
Type,
ClassVar,
)
import logging
from sqlalchemy import create_engine, URL, Engine
from sqlalchemy import orm, inspect, MetaData
@@ -13,8 +20,6 @@ from sqlalchemy.orm import (
declarative_base,
DeclarativeMeta,
)
from sqlalchemy.orm.session import _PKIdentityArgument
from sqlalchemy.orm.exc import UnmappedClassError
from sqlalchemy.pool import QueuePool
from dbgpt.util.string_utils import _to_str
@@ -27,16 +32,10 @@ T = TypeVar("T", bound="BaseModel")
class _QueryObject:
"""The query object."""
def __init__(self, db_manager: "DatabaseManager"):
self._db_manager = db_manager
def __get__(self, obj, type):
try:
mapper = orm.class_mapper(type)
if mapper:
return type.query_class(mapper, session=self._db_manager._session())
except UnmappedClassError:
return None
def __get__(self, obj: Union[_Model, None], model_cls: type[_Model]):
return model_cls.query_class(
model_cls, session=model_cls.__db_manager__._session()
)
class BaseQuery(orm.Query):
@@ -46,7 +45,9 @@ class BaseQuery(orm.Query):
"""Paginate the query.
Example:
.. code-block:: python
from dbgpt.storage.metadata import db, Model
class User(Model):
__tablename__ = "user"
@@ -58,10 +59,6 @@ class BaseQuery(orm.Query):
pagination = session.query(User).paginate_query(page=1, page_size=10)
print(pagination)
# Or you can use the query object
with db.session() as session:
pagination = User.query.paginate_query(page=1, page_size=10)
print(pagination)
Args:
page (Optional[int], optional): The page number. Defaults to 1.
@@ -86,26 +83,12 @@ class BaseQuery(orm.Query):
class _Model:
"""Base class for SQLAlchemy declarative base model.
"""Base class for SQLAlchemy declarative base model."""
With this class, we can use the query object to query the database.
__db_manager__: ClassVar[DatabaseManager]
query_class = BaseQuery
Examples:
.. code-block:: python
from dbgpt.storage.metadata import db, Model
class User(Model):
__tablename__ = "user"
id = Column(Integer, primary_key=True)
name = Column(String(50))
fullname = Column(String(50))
with db.session() as session:
# User is a subclass of _Model, so we can use the query object to query the database.
User.query.filter(User.name == "test").all()
"""
query_class = None
query: Optional[BaseQuery] = None
# query: Optional[BaseQuery] = _QueryObject()
def __repr__(self):
identity = inspect(self).identity
@@ -120,7 +103,9 @@ class DatabaseManager:
"""The database manager.
Examples:
.. code-block:: python
from urllib.parse import quote_plus as urlquote, quote
from dbgpt.storage.metadata import DatabaseManager, create_model
db = DatabaseManager()
@@ -141,21 +126,25 @@ class DatabaseManager:
session.add(User(name="test", fullname="test"))
# db will commit the session automatically by default.
# session.commit()
print(User.query.filter(User.name == "test").all())
assert session.query(User).filter(User.name == "test").first().name == "test"
# Use CRUDMixin APIs to create, update, delete, and query the database.
# More usage:
with db.session() as session:
User.create(**{"name": "test1", "fullname": "test1"})
User.create(**{"name": "test2", "fullname": "test1"})
users = User.all()
session.add(User(name="test1", fullname="test1"))
session.add(User(name="test2", fullname="test1"))
users = session.query(User).all()
print(users)
user = users[0]
user.update(**{"name": "test1_1111"})
user.name = "test1_1111"
session.merge(user)
user2 = users[1]
# Update user2 by save
user2.name = "test2_1111"
user2.save()
session.merge(user2)
session.commit()
# Delete user2
user2.delete()
"""
@@ -189,28 +178,65 @@ class DatabaseManager:
return self._engine is not None and self._session is not None
@contextmanager
def session(self) -> Session:
def session(self, commit: Optional[bool] = True) -> Session:
"""Get the session with context manager.
If any exception is raised, the session will roll back automatically; otherwise, the session will commit automatically.
This context manager handles the lifecycle of a SQLAlchemy session.
It automatically commits or rolls back transactions based on
the execution and handles session closure.
Example:
>>> with db.session() as session:
>>> session.query(...)
The `commit` parameter controls whether the session should commit
changes at the end of the block. This is useful for separating
read and write operations.
Returns:
Session: The session.
Examples:
.. code-block:: python
# For write operations (insert, update, delete):
with db.session() as session:
user = User(name="John Doe")
session.add(user)
# session.commit() is called automatically
# For read-only operations:
with db.session(commit=False) as session:
user = session.query(User).filter_by(name="John Doe").first()
# session.commit() is NOT called, as it's unnecessary for read operations
Args:
commit (Optional[bool], optional): Whether to commit the session.
If True (default), the session will commit changes at the end
of the block. Use False for read-only operations or when manual
control over commit is needed. Defaults to True.
Yields:
Session: The SQLAlchemy session object.
Raises:
RuntimeError: The database manager is not initialized.
Exception: Any exception.
RuntimeError: Raised if the database manager is not initialized.
Exception: Propagates any exception that occurred within the block.
Important Notes:
- DetachedInstanceError: This error occurs when trying to access or
modify an instance that has been detached from its session.
DetachedInstanceError can occur in scenarios where the session is
closed, and further interaction with the ORM object is attempted,
especially when accessing lazy-loaded attributes. To avoid this:
a. Ensure required attributes are loaded before session closure.
b. Avoid closing the session before all necessary interactions
with the ORM object are complete.
c. Re-bind the instance to a new session if further interaction
is required after the session is closed.
"""
if not self.is_initialized:
raise RuntimeError("The database manager is not initialized.")
session = self._session()
try:
yield session
session.commit()
if commit:
session.commit()
except:
session.rollback()
raise
@@ -223,7 +249,7 @@ class DatabaseManager:
"""Make the declarative base.
Args:
base (DeclarativeMeta): The base class.
model (DeclarativeMeta): The base class.
Returns:
DeclarativeMeta: The declarative base.
@@ -232,7 +258,8 @@ class DatabaseManager:
model = declarative_base(cls=model, name="Model")
if not getattr(model, "query_class", None):
model.query_class = self.Query
model.query = _QueryObject(self)
# model.query = _QueryObject()
model.__db_manager__ = self
return model
def init_db(
@@ -242,6 +269,7 @@ class DatabaseManager:
base: Optional[DeclarativeMeta] = None,
query_class=BaseQuery,
override_query_class: Optional[bool] = False,
session_options: Optional[Dict] = None,
):
"""Initialize the database manager.
@@ -251,18 +279,26 @@ class DatabaseManager:
base (Optional[DeclarativeMeta]): The base class. Defaults to None.
query_class (BaseQuery, optional): The query class. Defaults to BaseQuery.
override_query_class (Optional[bool], optional): Whether to override the query class. Defaults to False.
session_options (Optional[Dict], optional): The session options. Defaults to None.
"""
if session_options is None:
session_options = {}
self._db_url = db_url
if query_class is not None:
self.Query = query_class
if base is not None:
self._base = base
if not hasattr(base, "query") or override_query_class:
base.query = _QueryObject(self)
# if not hasattr(base, "query") or override_query_class:
# base.query = _QueryObject()
if not getattr(base, "query_class", None) or override_query_class:
base.query_class = self.Query
if not hasattr(base, "__db_manager__") or override_query_class:
base.__db_manager__ = self
self._engine = create_engine(db_url, **(engine_args or {}))
session_factory = sessionmaker(bind=self._engine)
session_options.setdefault("class_", Session)
session_options.setdefault("query_cls", self.Query)
session_factory = sessionmaker(bind=self._engine, **session_options)
self._session = scoped_session(session_factory)
self._base.metadata.bind = self._engine
@@ -397,35 +433,12 @@ class BaseCRUDMixin(Generic[T]):
__abstract__ = True
@classmethod
def create(cls: Type[T], **kwargs) -> T:
instance = cls(**kwargs)
return instance.save()
@classmethod
def all(cls: Type[T]) -> List[T]:
return cls.query.all()
@classmethod
def get(cls: Type[T], ident: _PKIdentityArgument) -> Optional[T]:
"""Get a record by its primary key identifier."""
def update(self: T, commit: Optional[bool] = True, **kwargs) -> T:
"""Update specific fields of a record."""
for attr, value in kwargs.items():
setattr(self, attr, value)
return commit and self.save() or self
@abc.abstractmethod
def save(self: T, commit: Optional[bool] = True) -> T:
"""Save the record."""
@abc.abstractmethod
def delete(self: T, commit: Optional[bool] = True) -> None:
"""Remove the record from the database."""
def db(cls) -> DatabaseManager:
"""Get the database manager."""
return cls.__db_manager__
class BaseModel(BaseCRUDMixin[T], _Model, Generic[T]):
"""The base model class that includes CRUD convenience methods."""
__abstract__ = True
@@ -438,28 +451,14 @@ def create_model(db_manager: DatabaseManager) -> Type[BaseModel[T]]:
_db_manager: DatabaseManager = db_manager
@classmethod
def set_db_manager(cls, db_manager: DatabaseManager):
def set_db(cls, db_manager: DatabaseManager):
# TODO: It is hard to replace this with a user-supplied DB connection
cls._db_manager = db_manager
@classmethod
def get(cls: Type[T], ident: _PKIdentityArgument) -> Optional[T]:
"""Get a record by its primary key identifier."""
return cls._db_manager._session().get(cls, ident)
def save(self: T, commit: Optional[bool] = True) -> T:
"""Save the record."""
session = self._db_manager._session()
session.add(self)
if commit:
session.commit()
return self
def delete(self: T, commit: Optional[bool] = True) -> None:
"""Remove the record from the database."""
session = self._db_manager._session()
session.delete(self)
return commit and session.commit()
def db(cls) -> DatabaseManager:
"""Get the database manager."""
return cls._db_manager
class _NewModel(CRUDMixin[T], db_manager.Model, Generic[T]):
"""Base model class that includes CRUD convenience methods."""
@@ -478,6 +477,7 @@ def initialize_db(
engine_args: Optional[Dict] = None,
base: Optional[DeclarativeMeta] = None,
try_to_create_db: Optional[bool] = False,
session_options: Optional[Dict] = None,
) -> DatabaseManager:
"""Initialize the database manager.
@@ -487,10 +487,11 @@ def initialize_db(
engine_args (Optional[Dict], optional): The engine arguments. Defaults to None.
base (Optional[DeclarativeMeta]): The base class. Defaults to None.
try_to_create_db (Optional[bool], optional): Whether to try to create the database. Defaults to False.
session_options (Optional[Dict], optional): The session options. Defaults to None.
Returns:
DatabaseManager: The database manager.
"""
db.init_db(db_url, engine_args, base)
db.init_db(db_url, engine_args, base, session_options=session_options)
if try_to_create_db:
try:
db.create_all()