DB-GPT/dbgpt/storage/metadata/db_storage.py
2024-03-15 15:42:46 +08:00

145 lines
5.4 KiB
Python

"""Database storage implementation using SQLAlchemy."""
from contextlib import contextmanager
from typing import Dict, Iterator, List, Optional, Type, Union
from sqlalchemy import URL
from sqlalchemy.orm import DeclarativeMeta, Session
from dbgpt.core import Serializer
from dbgpt.core.interface.storage import (
QuerySpec,
ResourceIdentifier,
StorageInterface,
StorageItemAdapter,
T,
)
from .db_manager import BaseModel, BaseQuery, DatabaseManager
def _copy_public_properties(src: BaseModel, dest: BaseModel):
    """Copy the non-``None`` column values of ``src`` onto ``dest``.

    The columns are taken from ``src.__table__``. The column named ``"id"``
    is never copied, and a column whose value on ``src`` is ``None`` leaves
    the corresponding attribute of ``dest`` untouched.

    Args:
        src: The instance to read column values from.
        dest: The instance that receives the values (modified in place).
    """
    column_names = (col.name for col in src.__table__.columns)  # type: ignore
    for name in column_names:
        if name == "id":
            # Keep the identity of the destination instance.
            continue
        new_value = getattr(src, name)
        if new_value is None:
            # ``None`` means "nothing to copy", not "clear the field".
            continue
        setattr(dest, name, new_value)
class SQLAlchemyStorage(StorageInterface[T, BaseModel]):
    """Database storage implementation using SQLAlchemy.

    Items are converted to and from ``model_class`` rows by the configured
    adapter, and every operation runs inside a session obtained from the
    underlying ``DatabaseManager``.
    """

    def __init__(
        self,
        db_url_or_db: Union[str, URL, DatabaseManager],
        model_class: Type[BaseModel],
        adapter: StorageItemAdapter[T, BaseModel],
        serializer: Optional[Serializer] = None,
        engine_args: Optional[Dict] = None,
        base: Optional[DeclarativeMeta] = None,
        query_class=BaseQuery,
    ):
        """Create a SQLAlchemyStorage instance.

        Args:
            db_url_or_db: A database URL or an existing ``DatabaseManager``;
                passed to ``DatabaseManager.build_from``.
            model_class: The model class that is queried by this storage.
            adapter: Converts between storage items and model instances.
            serializer: Optional serializer forwarded to the base class.
            engine_args: Optional arguments forwarded to
                ``DatabaseManager.build_from``.
            base: Optional declarative base forwarded to
                ``DatabaseManager.build_from``.
            query_class: Query class forwarded to
                ``DatabaseManager.build_from``.
        """
        super().__init__(serializer=serializer, adapter=adapter)
        self.db_manager = DatabaseManager.build_from(
            db_url_or_db, engine_args, base, query_class
        )
        self._model_class = model_class

    @contextmanager
    def session(self) -> Iterator[Session]:
        """Yield a session obtained from the database manager."""
        with self.db_manager.session() as session:
            yield session

    def _find_model(
        self, session: Session, resource_id: ResourceIdentifier
    ) -> Optional[BaseModel]:
        """Return the first stored row matching ``resource_id``, if any."""
        query = self.adapter.get_query_for_identifier(
            self._model_class, resource_id, session=session
        )
        return query.with_session(session).first()

    def _update_existing(self, session: Session, data: T) -> bool:
        """Copy ``data`` onto its stored row; return whether a row was found."""
        existing = self._find_model(session, data.identifier)
        if not existing:
            return False
        # Convert only once a row is known to exist, then copy column values.
        _copy_public_properties(self.adapter.to_storage_format(data), existing)
        session.merge(existing)
        return True

    def _filtered_query(self, session: Session, spec: QuerySpec):
        """Build a query on the model class filtered by ``spec.conditions``."""
        query = session.query(self._model_class)
        for key, value in spec.conditions.items():
            # Conditions whose value is ``None`` are ignored, not matched
            # against NULL.
            if value is not None:
                query = query.filter(getattr(self._model_class, key) == value)
        return query

    def save(self, data: T) -> None:
        """Save data to the storage.

        Args:
            data: The item to convert with the adapter and add to a session.
        """
        with self.session() as session:
            model_instance = self.adapter.to_storage_format(data)
            session.add(model_instance)

    def update(self, data: T) -> None:
        """Update data in the storage.

        If no stored row matches ``data.identifier``, nothing is changed.

        Args:
            data: The item whose non-``None`` column values are copied onto
                the stored row.
        """
        with self.session() as session:
            self._update_existing(session, data)

    def save_or_update(self, data: T) -> None:
        """Save or update data in the storage.

        Args:
            data: The item to update when a row with its identifier exists,
                or to save otherwise.
        """
        with self.session() as session:
            if self._update_existing(session, data):
                return
        # ``save`` opens its own session, so call it after the lookup
        # session's ``with`` block has been left.
        self.save(data)

    def load(self, resource_id: ResourceIdentifier, cls: Type[T]) -> Optional[T]:
        """Load data by identifier from the storage.

        Args:
            resource_id: The identifier of the item to load.
            cls: The type of the data (not used by this implementation).

        Returns:
            The item converted by the adapter, or ``None`` if no row matches.
        """
        with self.session() as session:
            model_instance = self._find_model(session, resource_id)
            if model_instance:
                # Convert while the session is still open.
                return self.adapter.from_storage_format(model_instance)
            return None

    def delete(self, resource_id: ResourceIdentifier) -> None:
        """Delete data by identifier from the storage.

        Does nothing when no row matches ``resource_id``.

        Args:
            resource_id: The identifier of the item to delete.
        """
        with self.session() as session:
            model_instance = self._find_model(session, resource_id)
            if model_instance:
                session.delete(model_instance)

    def query(self, spec: QuerySpec, cls: Type[T]) -> List[T]:
        """Query data from the storage.

        Args:
            spec (QuerySpec): The query specification
            cls (Type[T]): The type of the data (not used by this
                implementation)

        Returns:
            The matching rows converted by the adapter.
        """
        with self.session() as session:
            query = self._filtered_query(session, spec)
            if spec.limit is not None:
                query = query.limit(spec.limit)
            if spec.offset is not None:
                query = query.offset(spec.offset)
            model_instances = query.all()
            return [
                self.adapter.from_storage_format(instance)
                for instance in model_instances
            ]

    def count(self, spec: QuerySpec, cls: Type[T]) -> int:
        """Count the number of data in the storage.

        Only ``spec.conditions`` is applied; ``spec.limit`` and
        ``spec.offset`` are not.

        Args:
            spec (QuerySpec): The query specification
            cls (Type[T]): The type of the data (not used by this
                implementation)
        """
        with self.session() as session:
            return self._filtered_query(session, spec).count()