feat(core): More AWEL operators and new prompt manager API (#972)

Co-authored-by: csunny <cfqsunny@163.com>
This commit is contained in:
Fangyin Cheng
2023-12-25 20:03:22 +08:00
committed by GitHub
parent 048fb6c402
commit 69fb97e508
46 changed files with 2556 additions and 294 deletions

View File

@@ -236,6 +236,7 @@ class DatabaseManager:
engine_args: Optional[Dict] = None,
base: Optional[DeclarativeMeta] = None,
query_class=BaseQuery,
override_query_class: Optional[bool] = False,
):
"""Initialize the database manager.
@@ -244,15 +245,16 @@ class DatabaseManager:
engine_args (Optional[Dict], optional): The engine arguments. Defaults to None.
base (Optional[DeclarativeMeta]): The base class. Defaults to None.
query_class (BaseQuery, optional): The query class. Defaults to BaseQuery.
override_query_class (Optional[bool], optional): Whether to override the query class. Defaults to False.
"""
self._db_url = db_url
if query_class is not None:
self.Query = query_class
if base is not None:
self._base = base
if not hasattr(base, "query"):
if not hasattr(base, "query") or override_query_class:
base.query = _QueryObject(self)
if not getattr(base, "query_class", None):
if not getattr(base, "query_class", None) or override_query_class:
base.query_class = self.Query
self._engine = create_engine(db_url, **(engine_args or {}))
session_factory = sessionmaker(bind=self._engine)
@@ -299,6 +301,59 @@ class DatabaseManager:
def create_all(self):
self.Model.metadata.create_all(self._engine)
@staticmethod
def build_from(
db_url_or_db: Union[str, URL, DatabaseManager],
engine_args: Optional[Dict] = None,
base: Optional[DeclarativeMeta] = None,
query_class=BaseQuery,
override_query_class: Optional[bool] = False,
) -> DatabaseManager:
"""Build the database manager from the db_url_or_db.
Examples:
Build from the database url.
.. code-block:: python
from dbgpt.storage.metadata import DatabaseManager
from sqlalchemy import Column, Integer, String
db = DatabaseManager.build_from("sqlite:///:memory:")
class User(db.Model):
__tablename__ = "user"
id = Column(Integer, primary_key=True)
name = Column(String(50))
fullname = Column(String(50))
db.create_all()
with db.session() as session:
session.add(User(name="test", fullname="test"))
session.commit()
print(User.query.filter(User.name == "test").all())
Args:
db_url_or_db (Union[str, URL, DatabaseManager]): The database url or the database manager.
engine_args (Optional[Dict], optional): The engine arguments. Defaults to None.
base (Optional[DeclarativeMeta]): The base class. Defaults to None.
query_class (BaseQuery, optional): The query class. Defaults to BaseQuery.
override_query_class (Optional[bool], optional): Whether to override the query class. Defaults to False.
Returns:
DatabaseManager: The database manager.
"""
if isinstance(db_url_or_db, str) or isinstance(db_url_or_db, URL):
db_manager = DatabaseManager()
db_manager.init_db(
db_url_or_db, engine_args, base, query_class, override_query_class
)
return db_manager
elif isinstance(db_url_or_db, DatabaseManager):
return db_url_or_db
else:
raise ValueError(
f"db_url_or_db should be either url or a DatabaseManager, got {type(db_url_or_db)}"
)
db = DatabaseManager()
"""The global database manager.
@@ -375,14 +430,21 @@ def create_model(db_manager: DatabaseManager) -> Type[BaseModel[T]]:
class CRUDMixin(BaseCRUDMixin[T], Generic[T]):
"""Mixin that adds convenience methods for CRUD (create, read, update, delete)"""
_db_manager: DatabaseManager = db_manager
@classmethod
def set_db_manager(cls, db_manager: DatabaseManager):
# TODO: It is hard to replace to user DB Connection
cls._db_manager = db_manager
@classmethod
def get(cls: Type[T], ident: _PKIdentityArgument) -> Optional[T]:
"""Get a record by its primary key identifier."""
return db_manager._session().get(cls, ident)
return cls._db_manager._session().get(cls, ident)
def save(self: T, commit: Optional[bool] = True) -> T:
"""Save the record."""
session = db_manager._session()
session = self._db_manager._session()
session.add(self)
if commit:
session.commit()
@@ -390,7 +452,7 @@ def create_model(db_manager: DatabaseManager) -> Type[BaseModel[T]]:
def delete(self: T, commit: Optional[bool] = True) -> None:
"""Remove the record from the database."""
session = db_manager._session()
session = self._db_manager._session()
session.delete(self)
return commit and session.commit()