mirror of
https://github.com/csunny/DB-GPT.git
synced 2025-09-07 03:50:42 +00:00
feat(model): Proxy model support count token (#996)
This commit is contained in:
@@ -1,8 +1,15 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import abc
|
||||
from contextlib import contextmanager
|
||||
from typing import TypeVar, Generic, Union, Dict, Optional, Type, Iterator, List
|
||||
from typing import (
|
||||
TypeVar,
|
||||
Generic,
|
||||
Union,
|
||||
Dict,
|
||||
Optional,
|
||||
Type,
|
||||
ClassVar,
|
||||
)
|
||||
import logging
|
||||
from sqlalchemy import create_engine, URL, Engine
|
||||
from sqlalchemy import orm, inspect, MetaData
|
||||
@@ -13,8 +20,6 @@ from sqlalchemy.orm import (
|
||||
declarative_base,
|
||||
DeclarativeMeta,
|
||||
)
|
||||
from sqlalchemy.orm.session import _PKIdentityArgument
|
||||
from sqlalchemy.orm.exc import UnmappedClassError
|
||||
|
||||
from sqlalchemy.pool import QueuePool
|
||||
from dbgpt.util.string_utils import _to_str
|
||||
@@ -27,16 +32,10 @@ T = TypeVar("T", bound="BaseModel")
|
||||
class _QueryObject:
|
||||
"""The query object."""
|
||||
|
||||
def __init__(self, db_manager: "DatabaseManager"):
|
||||
self._db_manager = db_manager
|
||||
|
||||
def __get__(self, obj, type):
|
||||
try:
|
||||
mapper = orm.class_mapper(type)
|
||||
if mapper:
|
||||
return type.query_class(mapper, session=self._db_manager._session())
|
||||
except UnmappedClassError:
|
||||
return None
|
||||
def __get__(self, obj: Union[_Model, None], model_cls: type[_Model]):
|
||||
return model_cls.query_class(
|
||||
model_cls, session=model_cls.__db_manager__._session()
|
||||
)
|
||||
|
||||
|
||||
class BaseQuery(orm.Query):
|
||||
@@ -46,7 +45,9 @@ class BaseQuery(orm.Query):
|
||||
"""Paginate the query.
|
||||
|
||||
Example:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
from dbgpt.storage.metadata import db, Model
|
||||
class User(Model):
|
||||
__tablename__ = "user"
|
||||
@@ -58,10 +59,6 @@ class BaseQuery(orm.Query):
|
||||
pagination = session.query(User).paginate_query(page=1, page_size=10)
|
||||
print(pagination)
|
||||
|
||||
# Or you can use the query object
|
||||
with db.session() as session:
|
||||
pagination = User.query.paginate_query(page=1, page_size=10)
|
||||
print(pagination)
|
||||
|
||||
Args:
|
||||
page (Optional[int], optional): The page number. Defaults to 1.
|
||||
@@ -86,26 +83,12 @@ class BaseQuery(orm.Query):
|
||||
|
||||
|
||||
class _Model:
|
||||
"""Base class for SQLAlchemy declarative base model.
|
||||
"""Base class for SQLAlchemy declarative base model."""
|
||||
|
||||
With this class, we can use the query object to query the database.
|
||||
__db_manager__: ClassVar[DatabaseManager]
|
||||
query_class = BaseQuery
|
||||
|
||||
Examples:
|
||||
.. code-block:: python
|
||||
from dbgpt.storage.metadata import db, Model
|
||||
class User(Model):
|
||||
__tablename__ = "user"
|
||||
id = Column(Integer, primary_key=True)
|
||||
name = Column(String(50))
|
||||
fullname = Column(String(50))
|
||||
|
||||
with db.session() as session:
|
||||
# User is an instance of _Model, and we can use the query object to query the database.
|
||||
User.query.filter(User.name == "test").all()
|
||||
"""
|
||||
|
||||
query_class = None
|
||||
query: Optional[BaseQuery] = None
|
||||
# query: Optional[BaseQuery] = _QueryObject()
|
||||
|
||||
def __repr__(self):
|
||||
identity = inspect(self).identity
|
||||
@@ -120,7 +103,9 @@ class DatabaseManager:
|
||||
"""The database manager.
|
||||
|
||||
Examples:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
from urllib.parse import quote_plus as urlquote, quote
|
||||
from dbgpt.storage.metadata import DatabaseManager, create_model
|
||||
db = DatabaseManager()
|
||||
@@ -141,21 +126,25 @@ class DatabaseManager:
|
||||
session.add(User(name="test", fullname="test"))
|
||||
# db will commit the session automatically default.
|
||||
# session.commit()
|
||||
print(User.query.filter(User.name == "test").all())
|
||||
assert session.query(User).filter(User.name == "test").first().name == "test"
|
||||
|
||||
|
||||
# Use CURDMixin APIs to create, update, delete, query the database.
|
||||
# More usage:
|
||||
|
||||
with db.session() as session:
|
||||
User.create(**{"name": "test1", "fullname": "test1"})
|
||||
User.create(**{"name": "test2", "fullname": "test1"})
|
||||
users = User.all()
|
||||
session.add(User(name="test1", fullname="test1"))
|
||||
session.add(User(name="test2", fullname="test1"))
|
||||
users = session.query(User).all()
|
||||
print(users)
|
||||
user = users[0]
|
||||
user.update(**{"name": "test1_1111"})
|
||||
user.name = "test1_1111"
|
||||
session.merge(user)
|
||||
|
||||
user2 = users[1]
|
||||
# Update user2 by save
|
||||
user2.name = "test2_1111"
|
||||
user2.save()
|
||||
session.merge(user2)
|
||||
session.commit()
|
||||
# Delete user2
|
||||
user2.delete()
|
||||
"""
|
||||
@@ -189,28 +178,65 @@ class DatabaseManager:
|
||||
return self._engine is not None and self._session is not None
|
||||
|
||||
@contextmanager
|
||||
def session(self) -> Session:
|
||||
def session(self, commit: Optional[bool] = True) -> Session:
|
||||
"""Get the session with context manager.
|
||||
|
||||
If raise any exception, the session will roll back automatically, otherwise, the session will commit automatically.
|
||||
This context manager handles the lifecycle of a SQLAlchemy session.
|
||||
It automatically commits or rolls back transactions based on
|
||||
the execution and handles session closure.
|
||||
|
||||
Example:
|
||||
>>> with db.session() as session:
|
||||
>>> session.query(...)
|
||||
The `commit` parameter controls whether the session should commit
|
||||
changes at the end of the block. This is useful for separating
|
||||
read and write operations.
|
||||
|
||||
Returns:
|
||||
Session: The session.
|
||||
Examples:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
# For write operations (insert, update, delete):
|
||||
with db.session() as session:
|
||||
user = User(name="John Doe")
|
||||
session.add(user)
|
||||
# session.commit() is called automatically
|
||||
|
||||
# For read-only operations:
|
||||
with db.session(commit=False) as session:
|
||||
user = session.query(User).filter_by(name="John Doe").first()
|
||||
# session.commit() is NOT called, as it's unnecessary for read operations
|
||||
|
||||
Args:
|
||||
commit (Optional[bool], optional): Whether to commit the session.
|
||||
If True (default), the session will commit changes at the end
|
||||
of the block. Use False for read-only operations or when manual
|
||||
control over commit is needed. Defaults to True.
|
||||
|
||||
Yields:
|
||||
Session: The SQLAlchemy session object.
|
||||
|
||||
Raises:
|
||||
RuntimeError: The database manager is not initialized.
|
||||
Exception: Any exception.
|
||||
RuntimeError: Raised if the database manager is not initialized.
|
||||
Exception: Propagates any exception that occurred within the block.
|
||||
|
||||
Important Notes:
|
||||
- DetachedInstanceError: This error occurs when trying to access or
|
||||
modify an instance that has been detached from its session.
|
||||
DetachedInstanceError can occur in scenarios where the session is
|
||||
closed, and further interaction with the ORM object is attempted,
|
||||
especially when accessing lazy-loaded attributes. To avoid this:
|
||||
a. Ensure required attributes are loaded before session closure.
|
||||
b. Avoid closing the session before all necessary interactions
|
||||
with the ORM object are complete.
|
||||
c. Re-bind the instance to a new session if further interaction
|
||||
is required after the session is closed.
|
||||
|
||||
"""
|
||||
if not self.is_initialized:
|
||||
raise RuntimeError("The database manager is not initialized.")
|
||||
session = self._session()
|
||||
try:
|
||||
yield session
|
||||
session.commit()
|
||||
if commit:
|
||||
session.commit()
|
||||
except:
|
||||
session.rollback()
|
||||
raise
|
||||
@@ -223,7 +249,7 @@ class DatabaseManager:
|
||||
"""Make the declarative base.
|
||||
|
||||
Args:
|
||||
base (DeclarativeMeta): The base class.
|
||||
model (DeclarativeMeta): The base class.
|
||||
|
||||
Returns:
|
||||
DeclarativeMeta: The declarative base.
|
||||
@@ -232,7 +258,8 @@ class DatabaseManager:
|
||||
model = declarative_base(cls=model, name="Model")
|
||||
if not getattr(model, "query_class", None):
|
||||
model.query_class = self.Query
|
||||
model.query = _QueryObject(self)
|
||||
# model.query = _QueryObject()
|
||||
model.__db_manager__ = self
|
||||
return model
|
||||
|
||||
def init_db(
|
||||
@@ -242,6 +269,7 @@ class DatabaseManager:
|
||||
base: Optional[DeclarativeMeta] = None,
|
||||
query_class=BaseQuery,
|
||||
override_query_class: Optional[bool] = False,
|
||||
session_options: Optional[Dict] = None,
|
||||
):
|
||||
"""Initialize the database manager.
|
||||
|
||||
@@ -251,18 +279,26 @@ class DatabaseManager:
|
||||
base (Optional[DeclarativeMeta]): The base class. Defaults to None.
|
||||
query_class (BaseQuery, optional): The query class. Defaults to BaseQuery.
|
||||
override_query_class (Optional[bool], optional): Whether to override the query class. Defaults to False.
|
||||
session_options (Optional[Dict], optional): The session options. Defaults to None.
|
||||
"""
|
||||
if session_options is None:
|
||||
session_options = {}
|
||||
self._db_url = db_url
|
||||
if query_class is not None:
|
||||
self.Query = query_class
|
||||
if base is not None:
|
||||
self._base = base
|
||||
if not hasattr(base, "query") or override_query_class:
|
||||
base.query = _QueryObject(self)
|
||||
# if not hasattr(base, "query") or override_query_class:
|
||||
# base.query = _QueryObject()
|
||||
if not getattr(base, "query_class", None) or override_query_class:
|
||||
base.query_class = self.Query
|
||||
if not hasattr(base, "__db_manager__") or override_query_class:
|
||||
base.__db_manager__ = self
|
||||
self._engine = create_engine(db_url, **(engine_args or {}))
|
||||
session_factory = sessionmaker(bind=self._engine)
|
||||
|
||||
session_options.setdefault("class_", Session)
|
||||
session_options.setdefault("query_cls", self.Query)
|
||||
session_factory = sessionmaker(bind=self._engine, **session_options)
|
||||
self._session = scoped_session(session_factory)
|
||||
self._base.metadata.bind = self._engine
|
||||
|
||||
@@ -397,35 +433,12 @@ class BaseCRUDMixin(Generic[T]):
|
||||
__abstract__ = True
|
||||
|
||||
@classmethod
|
||||
def create(cls: Type[T], **kwargs) -> T:
|
||||
instance = cls(**kwargs)
|
||||
return instance.save()
|
||||
|
||||
@classmethod
|
||||
def all(cls: Type[T]) -> List[T]:
|
||||
return cls.query.all()
|
||||
|
||||
@classmethod
|
||||
def get(cls: Type[T], ident: _PKIdentityArgument) -> Optional[T]:
|
||||
"""Get a record by its primary key identifier."""
|
||||
|
||||
def update(self: T, commit: Optional[bool] = True, **kwargs) -> T:
|
||||
"""Update specific fields of a record."""
|
||||
for attr, value in kwargs.items():
|
||||
setattr(self, attr, value)
|
||||
return commit and self.save() or self
|
||||
|
||||
@abc.abstractmethod
|
||||
def save(self: T, commit: Optional[bool] = True) -> T:
|
||||
"""Save the record."""
|
||||
|
||||
@abc.abstractmethod
|
||||
def delete(self: T, commit: Optional[bool] = True) -> None:
|
||||
"""Remove the record from the database."""
|
||||
def db(cls) -> DatabaseManager:
|
||||
"""Get the database manager."""
|
||||
return cls.__db_manager__
|
||||
|
||||
|
||||
class BaseModel(BaseCRUDMixin[T], _Model, Generic[T]):
|
||||
|
||||
"""The base model class that includes CRUD convenience methods."""
|
||||
|
||||
__abstract__ = True
|
||||
@@ -438,28 +451,14 @@ def create_model(db_manager: DatabaseManager) -> Type[BaseModel[T]]:
|
||||
_db_manager: DatabaseManager = db_manager
|
||||
|
||||
@classmethod
|
||||
def set_db_manager(cls, db_manager: DatabaseManager):
|
||||
def set_db(cls, db_manager: DatabaseManager):
|
||||
# TODO: It is hard to replace to user DB Connection
|
||||
cls._db_manager = db_manager
|
||||
|
||||
@classmethod
|
||||
def get(cls: Type[T], ident: _PKIdentityArgument) -> Optional[T]:
|
||||
"""Get a record by its primary key identifier."""
|
||||
return cls._db_manager._session().get(cls, ident)
|
||||
|
||||
def save(self: T, commit: Optional[bool] = True) -> T:
|
||||
"""Save the record."""
|
||||
session = self._db_manager._session()
|
||||
session.add(self)
|
||||
if commit:
|
||||
session.commit()
|
||||
return self
|
||||
|
||||
def delete(self: T, commit: Optional[bool] = True) -> None:
|
||||
"""Remove the record from the database."""
|
||||
session = self._db_manager._session()
|
||||
session.delete(self)
|
||||
return commit and session.commit()
|
||||
def db(cls) -> DatabaseManager:
|
||||
"""Get the database manager."""
|
||||
return cls._db_manager
|
||||
|
||||
class _NewModel(CRUDMixin[T], db_manager.Model, Generic[T]):
|
||||
"""Base model class that includes CRUD convenience methods."""
|
||||
@@ -478,6 +477,7 @@ def initialize_db(
|
||||
engine_args: Optional[Dict] = None,
|
||||
base: Optional[DeclarativeMeta] = None,
|
||||
try_to_create_db: Optional[bool] = False,
|
||||
session_options: Optional[Dict] = None,
|
||||
) -> DatabaseManager:
|
||||
"""Initialize the database manager.
|
||||
|
||||
@@ -487,10 +487,11 @@ def initialize_db(
|
||||
engine_args (Optional[Dict], optional): The engine arguments. Defaults to None.
|
||||
base (Optional[DeclarativeMeta]): The base class. Defaults to None.
|
||||
try_to_create_db (Optional[bool], optional): Whether to try to create the database. Defaults to False.
|
||||
session_options (Optional[Dict], optional): The session options. Defaults to None.
|
||||
Returns:
|
||||
DatabaseManager: The database manager.
|
||||
"""
|
||||
db.init_db(db_url, engine_args, base)
|
||||
db.init_db(db_url, engine_args, base, session_options=session_options)
|
||||
if try_to_create_db:
|
||||
try:
|
||||
db.create_all()
|
||||
|
Reference in New Issue
Block a user