mirror of
https://github.com/csunny/DB-GPT.git
synced 2025-08-01 08:11:45 +00:00
145 lines
5.4 KiB
Python
145 lines
5.4 KiB
Python
"""Database storage implementation using SQLAlchemy."""
|
|
from contextlib import contextmanager
|
|
from typing import Dict, Iterator, List, Optional, Type, Union
|
|
|
|
from sqlalchemy import URL
|
|
from sqlalchemy.orm import DeclarativeMeta, Session
|
|
|
|
from dbgpt.core import Serializer
|
|
from dbgpt.core.interface.storage import (
|
|
QuerySpec,
|
|
ResourceIdentifier,
|
|
StorageInterface,
|
|
StorageItemAdapter,
|
|
T,
|
|
)
|
|
|
|
from .db_manager import BaseModel, BaseQuery, DatabaseManager
|
|
|
|
|
|
def _copy_public_properties(src: BaseModel, dest: BaseModel):
    """Copy column values from ``src`` onto ``dest``, skipping ``id``.

    The columns visited are those listed in ``src.__table__.columns``; each
    is read and written by its column name. A value that is ``None`` on
    ``src`` is not copied, so ``dest`` keeps whatever it already holds for
    that column (i.e. this cannot be used to clear a field to ``None``).

    Args:
        src: Model instance whose column values are read.
        dest: Model instance that is mutated in place.
    """
    wanted_names = (
        col.name
        for col in src.__table__.columns  # type: ignore
        if col.name != "id"  # never overwrite the destination's primary key
    )
    for name in wanted_names:
        new_value = getattr(src, name)
        if new_value is None:
            continue
        setattr(dest, name, new_value)
|
|
|
|
|
|
class SQLAlchemyStorage(StorageInterface[T, BaseModel]):
    """Database storage implementation using SQLAlchemy.

    Every public operation opens its own session through :meth:`session`,
    which delegates to ``self.db_manager.session()``. Conversion between
    storage items (``T``) and ORM model instances, and construction of the
    identifier lookup query, are delegated to ``self.adapter``.
    """

    def __init__(
        self,
        db_url_or_db: Union[str, URL, DatabaseManager],
        model_class: Type[BaseModel],
        adapter: StorageItemAdapter[T, BaseModel],
        serializer: Optional[Serializer] = None,
        engine_args: Optional[Dict] = None,
        base: Optional[DeclarativeMeta] = None,
        query_class=BaseQuery,
    ):
        """Create a SQLAlchemyStorage instance.

        Args:
            db_url_or_db: Database URL (string or SQLAlchemy ``URL``) or an
                existing ``DatabaseManager``; passed to
                ``DatabaseManager.build_from``.
            model_class: ORM model class whose rows hold the stored items.
            adapter: Converts items to/from ``model_class`` instances and
                builds the per-identifier lookup query.
            serializer: Optional serializer forwarded to the base class.
            engine_args: Optional engine arguments forwarded to
                ``DatabaseManager.build_from``.
            base: Optional declarative base forwarded to
                ``DatabaseManager.build_from``.
            query_class: Query class forwarded to
                ``DatabaseManager.build_from``.
        """
        super().__init__(serializer=serializer, adapter=adapter)
        self.db_manager = DatabaseManager.build_from(
            db_url_or_db, engine_args, base, query_class
        )
        self._model_class = model_class

    @contextmanager
    def session(self) -> Iterator[Session]:
        """Yield a session obtained from ``self.db_manager.session()``."""
        with self.db_manager.session() as session:
            yield session

    def _find_model_instance(
        self, session: Session, resource_id: ResourceIdentifier
    ) -> Optional[BaseModel]:
        """Return the first model instance matching ``resource_id`` in ``session``.

        Returns ``None`` when no row matches.
        """
        query = self.adapter.get_query_for_identifier(
            self._model_class, resource_id, session=session
        )
        return query.with_session(session).first()

    def _build_filtered_query(self, session: Session, spec: QuerySpec):
        """Return a query over ``model_class`` filtered by ``spec.conditions``.

        Conditions whose value is ``None`` are ignored; every other condition
        becomes an equality filter on the model attribute with that name
        (an unknown name raises ``AttributeError``).
        """
        query = session.query(self._model_class)
        for key, value in spec.conditions.items():
            if value is not None:
                query = query.filter(getattr(self._model_class, key) == value)
        return query

    def save(self, data: T) -> None:
        """Save data to the storage by adding a new model instance."""
        with self.session() as session:
            model_instance = self.adapter.to_storage_format(data)
            session.add(model_instance)

    def update(self, data: T) -> None:
        """Update data in the storage.

        The non-``None`` column values (except ``id``) of the converted
        ``data`` are copied onto the existing row matching
        ``data.identifier``. Does nothing if no such row exists.
        """
        with self.session() as session:
            exist_model_instance = self._find_model_instance(
                session, data.identifier
            )
            if exist_model_instance:
                _copy_public_properties(
                    self.adapter.to_storage_format(data), exist_model_instance
                )
                session.merge(exist_model_instance)

    def save_or_update(self, data: T) -> None:
        """Save or update data in the storage.

        If a row matching ``data.identifier`` exists it is updated the same
        way as :meth:`update`; otherwise the item is inserted via
        :meth:`save`.
        """
        with self.session() as session:
            model_instance = self._find_model_instance(session, data.identifier)
            if model_instance:
                new_instance = self.adapter.to_storage_format(data)
                _copy_public_properties(new_instance, model_instance)
                session.merge(model_instance)
                return
        # Insert only after the lookup session has closed: calling save()
        # inside the ``with`` block would open a second session nested in
        # the first (whether that is a separate transaction depends on
        # DatabaseManager.session(); on SQLite it can contend for the lock).
        self.save(data)

    def load(self, resource_id: ResourceIdentifier, cls: Type[T]) -> Optional[T]:
        """Load data by identifier from the storage.

        Args:
            resource_id: Identifier of the item to load.
            cls: The type of the data; accepted for interface compatibility
                and not used by this implementation.

        Returns:
            The converted item, or ``None`` if no row matches.
        """
        with self.session() as session:
            model_instance = self._find_model_instance(session, resource_id)
            if model_instance:
                # Convert while the session is still open so attribute loads
                # on the model instance can still reach the database.
                return self.adapter.from_storage_format(model_instance)
            return None

    def delete(self, resource_id: ResourceIdentifier) -> None:
        """Delete data by identifier from the storage.

        Does nothing if no row matches ``resource_id``.
        """
        with self.session() as session:
            model_instance = self._find_model_instance(session, resource_id)
            if model_instance:
                session.delete(model_instance)

    def query(self, spec: QuerySpec, cls: Type[T]) -> List[T]:
        """Query data from the storage.

        Args:
            spec (QuerySpec): The query specification; ``None``-valued
                conditions are ignored, ``limit``/``offset`` are applied when
                not ``None``.
            cls (Type[T]): The type of the data; not used by this
                implementation.

        Returns:
            The matching items, converted by the adapter.
        """
        with self.session() as session:
            query = self._build_filtered_query(session, spec)
            if spec.limit is not None:
                query = query.limit(spec.limit)
            if spec.offset is not None:
                query = query.offset(spec.offset)
            model_instances = query.all()
            return [
                self.adapter.from_storage_format(instance)
                for instance in model_instances
            ]

    def count(self, spec: QuerySpec, cls: Type[T]) -> int:
        """Count the number of data in the storage.

        Args:
            spec (QuerySpec): The query specification; only ``conditions``
                is used (``None``-valued conditions are ignored).
            cls (Type[T]): The type of the data; not used by this
                implementation.

        Returns:
            The number of rows matching the conditions.
        """
        with self.session() as session:
            return self._build_filtered_query(session, spec).count()
|