feat(core): More AWEL operators and new prompt manager API (#972)

Co-authored-by: csunny <cfqsunny@163.com>
This commit is contained in:
Fangyin Cheng
2023-12-25 20:03:22 +08:00
committed by GitHub
parent 048fb6c402
commit 69fb97e508
46 changed files with 2556 additions and 294 deletions

View File

@@ -236,6 +236,7 @@ class DatabaseManager:
engine_args: Optional[Dict] = None,
base: Optional[DeclarativeMeta] = None,
query_class=BaseQuery,
override_query_class: Optional[bool] = False,
):
"""Initialize the database manager.
@@ -244,15 +245,16 @@ class DatabaseManager:
engine_args (Optional[Dict], optional): The engine arguments. Defaults to None.
base (Optional[DeclarativeMeta]): The base class. Defaults to None.
query_class (BaseQuery, optional): The query class. Defaults to BaseQuery.
override_query_class (Optional[bool], optional): Whether to override the query class. Defaults to False.
"""
self._db_url = db_url
if query_class is not None:
self.Query = query_class
if base is not None:
self._base = base
if not hasattr(base, "query"):
if not hasattr(base, "query") or override_query_class:
base.query = _QueryObject(self)
if not getattr(base, "query_class", None):
if not getattr(base, "query_class", None) or override_query_class:
base.query_class = self.Query
self._engine = create_engine(db_url, **(engine_args or {}))
session_factory = sessionmaker(bind=self._engine)
@@ -299,6 +301,59 @@ class DatabaseManager:
def create_all(self):
self.Model.metadata.create_all(self._engine)
@staticmethod
def build_from(
db_url_or_db: Union[str, URL, DatabaseManager],
engine_args: Optional[Dict] = None,
base: Optional[DeclarativeMeta] = None,
query_class=BaseQuery,
override_query_class: Optional[bool] = False,
) -> DatabaseManager:
"""Build the database manager from the db_url_or_db.
Examples:
Build from the database url.
.. code-block:: python
from dbgpt.storage.metadata import DatabaseManager
from sqlalchemy import Column, Integer, String
db = DatabaseManager.build_from("sqlite:///:memory:")
class User(db.Model):
__tablename__ = "user"
id = Column(Integer, primary_key=True)
name = Column(String(50))
fullname = Column(String(50))
db.create_all()
with db.session() as session:
session.add(User(name="test", fullname="test"))
session.commit()
print(User.query.filter(User.name == "test").all())
Args:
db_url_or_db (Union[str, URL, DatabaseManager]): The database url or the database manager.
engine_args (Optional[Dict], optional): The engine arguments. Defaults to None.
base (Optional[DeclarativeMeta]): The base class. Defaults to None.
query_class (BaseQuery, optional): The query class. Defaults to BaseQuery.
override_query_class (Optional[bool], optional): Whether to override the query class. Defaults to False.
Returns:
DatabaseManager: The database manager.
"""
if isinstance(db_url_or_db, str) or isinstance(db_url_or_db, URL):
db_manager = DatabaseManager()
db_manager.init_db(
db_url_or_db, engine_args, base, query_class, override_query_class
)
return db_manager
elif isinstance(db_url_or_db, DatabaseManager):
return db_url_or_db
else:
raise ValueError(
f"db_url_or_db should be either url or a DatabaseManager, got {type(db_url_or_db)}"
)
db = DatabaseManager()
"""The global database manager.
@@ -375,14 +430,21 @@ def create_model(db_manager: DatabaseManager) -> Type[BaseModel[T]]:
class CRUDMixin(BaseCRUDMixin[T], Generic[T]):
"""Mixin that adds convenience methods for CRUD (create, read, update, delete)"""
_db_manager: DatabaseManager = db_manager
@classmethod
def set_db_manager(cls, db_manager: DatabaseManager):
# TODO: It is hard to replace to user DB Connection
cls._db_manager = db_manager
@classmethod
def get(cls: Type[T], ident: _PKIdentityArgument) -> Optional[T]:
"""Get a record by its primary key identifier."""
return db_manager._session().get(cls, ident)
return cls._db_manager._session().get(cls, ident)
def save(self: T, commit: Optional[bool] = True) -> T:
"""Save the record."""
session = db_manager._session()
session = self._db_manager._session()
session.add(self)
if commit:
session.commit()
@@ -390,7 +452,7 @@ def create_model(db_manager: DatabaseManager) -> Type[BaseModel[T]]:
def delete(self: T, commit: Optional[bool] = True) -> None:
"""Remove the record from the database."""
session = db_manager._session()
session = self._db_manager._session()
session.delete(self)
return commit and session.commit()

View File

@@ -34,16 +34,9 @@ class SQLAlchemyStorage(StorageInterface[T, BaseModel]):
query_class=BaseQuery,
):
super().__init__(serializer=serializer, adapter=adapter)
if isinstance(db_url_or_db, str) or isinstance(db_url_or_db, URL):
db_manager = DatabaseManager()
db_manager.init_db(db_url_or_db, engine_args, base, query_class)
self.db_manager = db_manager
elif isinstance(db_url_or_db, DatabaseManager):
self.db_manager = db_url_or_db
else:
raise ValueError(
f"db_url_or_db should be either url or a DatabaseManager, got {type(db_url_or_db)}"
)
self.db_manager = DatabaseManager.build_from(
db_url_or_db, engine_args, base, query_class
)
self._model_class = model_class
@contextmanager

View File

@@ -1,5 +1,6 @@
from __future__ import annotations
import pytest
import tempfile
from typing import Type
from dbgpt.storage.metadata.db_manager import (
DatabaseManager,
@@ -103,7 +104,6 @@ def test_pagination_query(db: DatabaseManager, Model: Type[BaseModel]):
db.create_all()
# 添加数据
with db.session() as session:
for i in range(30):
user = User(name=f"User {i}")
@@ -127,3 +127,29 @@ def test_invalid_pagination(db: DatabaseManager, Model: Type[BaseModel]):
User.query.paginate_query(page=0, per_page=10)
with pytest.raises(ValueError):
User.query.paginate_query(page=1, per_page=-1)
def test_set_model_db_manager(db: DatabaseManager, Model: Type[BaseModel]):
assert db.metadata.tables == {}
class User(Model):
__tablename__ = "user"
id = Column(Integer, primary_key=True)
name = Column(String(50))
with tempfile.NamedTemporaryFile(delete=True) as db_file:
filename = db_file.name
new_db = DatabaseManager.build_from(
f"sqlite:///{filename}", base=Model, override_query_class=True
)
Model.set_db_manager(new_db)
new_db.create_all()
db.create_all()
assert list(new_db.metadata.tables.keys())[0] == "user"
User.create(**{"name": "John Doe"})
with new_db.session() as session:
assert session.query(User).filter_by(name="John Doe").first() is not None
with db.session() as session:
assert session.query(User).filter_by(name="John Doe").first() is None
assert len(User.query.all()) == 1
assert User.query.filter(User.name == "John Doe").first().name == "John Doe"