mirror of
https://github.com/csunny/DB-GPT.git
synced 2025-09-05 11:01:09 +00:00
feat(core): More AWEL operators and new prompt manager API (#972)
Co-authored-by: csunny <cfqsunny@163.com>
This commit is contained in:
@@ -236,6 +236,7 @@ class DatabaseManager:
|
||||
engine_args: Optional[Dict] = None,
|
||||
base: Optional[DeclarativeMeta] = None,
|
||||
query_class=BaseQuery,
|
||||
override_query_class: Optional[bool] = False,
|
||||
):
|
||||
"""Initialize the database manager.
|
||||
|
||||
@@ -244,15 +245,16 @@ class DatabaseManager:
|
||||
engine_args (Optional[Dict], optional): The engine arguments. Defaults to None.
|
||||
base (Optional[DeclarativeMeta]): The base class. Defaults to None.
|
||||
query_class (BaseQuery, optional): The query class. Defaults to BaseQuery.
|
||||
override_query_class (Optional[bool], optional): Whether to override the query class. Defaults to False.
|
||||
"""
|
||||
self._db_url = db_url
|
||||
if query_class is not None:
|
||||
self.Query = query_class
|
||||
if base is not None:
|
||||
self._base = base
|
||||
if not hasattr(base, "query"):
|
||||
if not hasattr(base, "query") or override_query_class:
|
||||
base.query = _QueryObject(self)
|
||||
if not getattr(base, "query_class", None):
|
||||
if not getattr(base, "query_class", None) or override_query_class:
|
||||
base.query_class = self.Query
|
||||
self._engine = create_engine(db_url, **(engine_args or {}))
|
||||
session_factory = sessionmaker(bind=self._engine)
|
||||
@@ -299,6 +301,59 @@ class DatabaseManager:
|
||||
def create_all(self):
|
||||
self.Model.metadata.create_all(self._engine)
|
||||
|
||||
@staticmethod
|
||||
def build_from(
|
||||
db_url_or_db: Union[str, URL, DatabaseManager],
|
||||
engine_args: Optional[Dict] = None,
|
||||
base: Optional[DeclarativeMeta] = None,
|
||||
query_class=BaseQuery,
|
||||
override_query_class: Optional[bool] = False,
|
||||
) -> DatabaseManager:
|
||||
"""Build the database manager from the db_url_or_db.
|
||||
|
||||
Examples:
|
||||
|
||||
Build from the database url.
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
from dbgpt.storage.metadata import DatabaseManager
|
||||
from sqlalchemy import Column, Integer, String
|
||||
db = DatabaseManager.build_from("sqlite:///:memory:")
|
||||
class User(db.Model):
|
||||
__tablename__ = "user"
|
||||
id = Column(Integer, primary_key=True)
|
||||
name = Column(String(50))
|
||||
fullname = Column(String(50))
|
||||
db.create_all()
|
||||
with db.session() as session:
|
||||
session.add(User(name="test", fullname="test"))
|
||||
session.commit()
|
||||
print(User.query.filter(User.name == "test").all())
|
||||
|
||||
Args:
|
||||
db_url_or_db (Union[str, URL, DatabaseManager]): The database url or the database manager.
|
||||
engine_args (Optional[Dict], optional): The engine arguments. Defaults to None.
|
||||
base (Optional[DeclarativeMeta]): The base class. Defaults to None.
|
||||
query_class (BaseQuery, optional): The query class. Defaults to BaseQuery.
|
||||
override_query_class (Optional[bool], optional): Whether to override the query class. Defaults to False.
|
||||
|
||||
Returns:
|
||||
DatabaseManager: The database manager.
|
||||
"""
|
||||
if isinstance(db_url_or_db, str) or isinstance(db_url_or_db, URL):
|
||||
db_manager = DatabaseManager()
|
||||
db_manager.init_db(
|
||||
db_url_or_db, engine_args, base, query_class, override_query_class
|
||||
)
|
||||
return db_manager
|
||||
elif isinstance(db_url_or_db, DatabaseManager):
|
||||
return db_url_or_db
|
||||
else:
|
||||
raise ValueError(
|
||||
f"db_url_or_db should be either url or a DatabaseManager, got {type(db_url_or_db)}"
|
||||
)
|
||||
|
||||
|
||||
db = DatabaseManager()
|
||||
"""The global database manager.
|
||||
@@ -375,14 +430,21 @@ def create_model(db_manager: DatabaseManager) -> Type[BaseModel[T]]:
|
||||
class CRUDMixin(BaseCRUDMixin[T], Generic[T]):
|
||||
"""Mixin that adds convenience methods for CRUD (create, read, update, delete)"""
|
||||
|
||||
_db_manager: DatabaseManager = db_manager
|
||||
|
||||
@classmethod
|
||||
def set_db_manager(cls, db_manager: DatabaseManager):
|
||||
# TODO: It is hard to replace to user DB Connection
|
||||
cls._db_manager = db_manager
|
||||
|
||||
@classmethod
|
||||
def get(cls: Type[T], ident: _PKIdentityArgument) -> Optional[T]:
|
||||
"""Get a record by its primary key identifier."""
|
||||
return db_manager._session().get(cls, ident)
|
||||
return cls._db_manager._session().get(cls, ident)
|
||||
|
||||
def save(self: T, commit: Optional[bool] = True) -> T:
|
||||
"""Save the record."""
|
||||
session = db_manager._session()
|
||||
session = self._db_manager._session()
|
||||
session.add(self)
|
||||
if commit:
|
||||
session.commit()
|
||||
@@ -390,7 +452,7 @@ def create_model(db_manager: DatabaseManager) -> Type[BaseModel[T]]:
|
||||
|
||||
def delete(self: T, commit: Optional[bool] = True) -> None:
|
||||
"""Remove the record from the database."""
|
||||
session = db_manager._session()
|
||||
session = self._db_manager._session()
|
||||
session.delete(self)
|
||||
return commit and session.commit()
|
||||
|
||||
|
Reference in New Issue
Block a user