import contextlib
from pathlib import Path
from typing import (
    Any,
    AsyncGenerator,
    AsyncIterator,
    Dict,
    Generator,
    Iterator,
    List,
    Optional,
    Sequence,
    Tuple,
    Union,
    cast,
)

from langchain_core.stores import BaseStore
from sqlalchemy import (
    LargeBinary,
    Text,
    and_,
    create_engine,
    delete,
    select,
)
from sqlalchemy.engine.base import Engine
from sqlalchemy.ext.asyncio import (
    AsyncEngine,
    AsyncSession,
    create_async_engine,
)
from sqlalchemy.orm import (
    Mapped,
    Session,
    declarative_base,
    sessionmaker,
)

try:
    from sqlalchemy.ext.asyncio import async_sessionmaker
except ImportError:
    # dummy for sqlalchemy < 2
    async_sessionmaker = type("async_sessionmaker", (type,), {})  # type: ignore

Base = declarative_base()

try:
    from sqlalchemy.orm import mapped_column

    class LangchainKeyValueStores(Base):  # type: ignore[valid-type,misc]
        """Table used to save values."""

        # ATTENTION:
        # Prior to modifying this table, please determine whether
        # we should create migrations for this table to make sure
        # users do not experience data loss.
        __tablename__ = "langchain_key_value_stores"

        namespace: Mapped[str] = mapped_column(
            primary_key=True, index=True, nullable=False
        )
        key: Mapped[str] = mapped_column(primary_key=True, index=True, nullable=False)
        value = mapped_column(LargeBinary, index=False, nullable=False)

except ImportError:
    # dummy for sqlalchemy < 2
    from sqlalchemy import Column

    class LangchainKeyValueStores(Base):  # type: ignore[valid-type,misc,no-redef]
        """Table used to save values."""

        # ATTENTION:
        # Prior to modifying this table, please determine whether
        # we should create migrations for this table to make sure
        # users do not experience data loss.
        __tablename__ = "langchain_key_value_stores"

        namespace = Column(Text(), primary_key=True, index=True, nullable=False)
        key = Column(Text(), primary_key=True, index=True, nullable=False)
        value = Column(LargeBinary, index=False, nullable=False)


def items_equal(x: Any, y: Any) -> bool:
    return x == y


# This is a fix of the original SQLStore.
# It can be removed once the corresponding upstream PR is merged.
class SQLStore(BaseStore[str, bytes]):
    """BaseStore interface that works on an SQL database.

    Examples:
        Create a SQLStore instance and perform operations on it:

        .. code-block:: python

            from langchain_rag.storage import SQLStore

            # Instantiate the SQLStore with an in-memory SQLite database
            sql_store = SQLStore(namespace="test", db_url="sqlite:///:memory:")
            sql_store.create_schema()

            # Set values for keys
            sql_store.mset([("key1", b"value1"), ("key2", b"value2")])

            # Get values for keys
            values = sql_store.mget(["key1", "key2"])  # Returns [b"value1", b"value2"]

            # Delete keys
            sql_store.mdelete(["key1"])

            # Iterate over keys
            for key in sql_store.yield_keys():
                print(key)
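
        The store can also be used asynchronously. The sketch below is
        illustrative and not part of the original example; it assumes an
        async-capable driver such as ``aiosqlite`` is installed and that the
        code runs inside an event loop:

        .. code-block:: python

            sql_store = SQLStore(
                namespace="test",
                db_url="sqlite+aiosqlite:///:memory:",
                async_mode=True,
            )
            await sql_store.acreate_schema()

            await sql_store.amset([("key1", b"value1")])
            values = await sql_store.amget(["key1"])  # Returns [b"value1"]
            await sql_store.amdelete(["key1"])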
"""
|
|
|
|
def __init__(
|
|
self,
|
|
*,
|
|
namespace: str,
|
|
db_url: Optional[Union[str, Path]] = None,
|
|
engine: Optional[Union[Engine, AsyncEngine]] = None,
|
|
engine_kwargs: Optional[Dict[str, Any]] = None,
|
|
async_mode: Optional[bool] = None,
|
|
):
|
|
if db_url is None and engine is None:
|
|
raise ValueError("Must specify either db_url or engine")
|
|
|
|
if db_url is not None and engine is not None:
|
|
raise ValueError("Must specify either db_url or engine, not both")
|
|
|
|
_engine: Union[Engine, AsyncEngine]
|
|
if db_url:
|
|
if async_mode is None:
|
|
async_mode = False
|
|
if async_mode:
|
|
_engine = create_async_engine(
|
|
url=str(db_url),
|
|
**(engine_kwargs or {}),
|
|
)
|
|
else:
|
|
_engine = create_engine(url=str(db_url), **(engine_kwargs or {}))
|
|
elif engine:
|
|
_engine = engine
|
|
|
|
else:
|
|
raise AssertionError("Something went wrong with configuration of engine.")
|
|
|
|
_session_maker: Union[sessionmaker[Session], async_sessionmaker[AsyncSession]]
|
|
if isinstance(_engine, AsyncEngine):
|
|
self.async_mode = True
|
|
_session_maker = async_sessionmaker(bind=_engine)
|
|
else:
|
|
self.async_mode = False
|
|
_session_maker = sessionmaker(bind=_engine)
|
|
|
|
self.engine = _engine
|
|
self.dialect = _engine.dialect.name
|
|
self.session_maker = _session_maker
|
|
self.namespace = namespace
|
|
|
|
    def create_schema(self) -> None:
        """Create the key-value table if it does not already exist."""
        Base.metadata.create_all(self.engine)  # problem in sqlalchemy v1
        # sqlalchemy.exc.CompileError: (in table 'langchain_key_value_stores',
        # column 'namespace'): Can't generate DDL for NullType(); did you forget
        # to specify a type on this Column?

    async def acreate_schema(self) -> None:
        """Create the key-value table using the async engine."""
        assert isinstance(self.engine, AsyncEngine)
        async with self.engine.begin() as session:
            await session.run_sync(Base.metadata.create_all)

    def drop(self) -> None:
        """Drop the store's table(s) defined on the shared metadata."""
        Base.metadata.drop_all(bind=self.engine.connect())

    async def amget(self, keys: Sequence[str]) -> List[Optional[bytes]]:
        """Get the values for the given keys (async); missing keys map to None."""
        assert isinstance(self.engine, AsyncEngine)
        result: Dict[str, bytes] = {}
        async with self._make_async_session() as session:
            stmt = select(LangchainKeyValueStores).filter(
                and_(
                    LangchainKeyValueStores.key.in_(keys),
                    LangchainKeyValueStores.namespace == self.namespace,
                )
            )
            for v in await session.scalars(stmt):
                result[v.key] = v.value
        return [result.get(key) for key in keys]

    def mget(self, keys: Sequence[str]) -> List[Optional[bytes]]:
        """Get the values for the given keys; missing keys map to None."""
        result = {}

        with self._make_sync_session() as session:
            stmt = select(LangchainKeyValueStores).filter(
                and_(
                    LangchainKeyValueStores.key.in_(keys),
                    LangchainKeyValueStores.namespace == self.namespace,
                )
            )
            for v in session.scalars(stmt):
                result[v.key] = v.value
        return [result.get(key) for key in keys]

    async def amset(self, key_value_pairs: Sequence[Tuple[str, bytes]]) -> None:
        """Set the given key-value pairs (async), overwriting existing keys."""
        async with self._make_async_session() as session:
            await self._amdelete([key for key, _ in key_value_pairs], session)
            session.add_all(
                [
                    LangchainKeyValueStores(namespace=self.namespace, key=k, value=v)
                    for k, v in key_value_pairs
                ]
            )
            await session.commit()

    def mset(self, key_value_pairs: Sequence[Tuple[str, bytes]]) -> None:
        """Set the given key-value pairs, overwriting existing keys."""
        values: Dict[str, bytes] = dict(key_value_pairs)
        with self._make_sync_session() as session:
            self._mdelete(list(values.keys()), session)
            session.add_all(
                [
                    LangchainKeyValueStores(namespace=self.namespace, key=k, value=v)
                    for k, v in values.items()
                ]
            )
            session.commit()

    def _mdelete(self, keys: Sequence[str], session: Session) -> None:
        stmt = delete(LangchainKeyValueStores).filter(
            and_(
                LangchainKeyValueStores.key.in_(keys),
                LangchainKeyValueStores.namespace == self.namespace,
            )
        )
        session.execute(stmt)

    async def _amdelete(self, keys: Sequence[str], session: AsyncSession) -> None:
        stmt = delete(LangchainKeyValueStores).filter(
            and_(
                LangchainKeyValueStores.key.in_(keys),
                LangchainKeyValueStores.namespace == self.namespace,
            )
        )
        await session.execute(stmt)

    def mdelete(self, keys: Sequence[str]) -> None:
        """Delete the given keys from the store."""
        with self._make_sync_session() as session:
            self._mdelete(keys, session)
            session.commit()

    async def amdelete(self, keys: Sequence[str]) -> None:
        """Delete the given keys from the store (async)."""
        async with self._make_async_session() as session:
            await self._amdelete(keys, session)
            await session.commit()

    def yield_keys(self, *, prefix: Optional[str] = None) -> Iterator[str]:
        """Yield all keys in the namespace, optionally filtered by prefix."""
        with self._make_sync_session() as session:
            for v in session.query(LangchainKeyValueStores).filter(  # type: ignore
                LangchainKeyValueStores.namespace == self.namespace
            ):
                if str(v.key).startswith(prefix or ""):
                    yield str(v.key)
            session.close()

    async def ayield_keys(self, *, prefix: Optional[str] = None) -> AsyncIterator[str]:
        """Yield all keys in the namespace (async), optionally filtered by prefix."""
        async with self._make_async_session() as session:
            stmt = select(LangchainKeyValueStores).filter(
                LangchainKeyValueStores.namespace == self.namespace
            )
            for v in await session.scalars(stmt):
                if str(v.key).startswith(prefix or ""):
                    yield str(v.key)
            await session.close()

    @contextlib.contextmanager
    def _make_sync_session(self) -> Generator[Session, None, None]:
        """Make a sync session."""
        if self.async_mode:
            raise ValueError(
                "Attempting to use a sync method when async mode is turned on. "
                "Please use the corresponding async method instead."
            )
        with cast(Session, self.session_maker()) as session:
            yield cast(Session, session)

    @contextlib.asynccontextmanager
    async def _make_async_session(self) -> AsyncGenerator[AsyncSession, None]:
        """Make an async session."""
        if not self.async_mode:
            raise ValueError(
                "Attempting to use an async method when sync mode is turned on. "
                "Please use the corresponding sync method instead."
            )
        async with cast(AsyncSession, self.session_maker()) as session:
            yield cast(AsyncSession, session)
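

if __name__ == "__main__":
    # Illustrative sketch only, not part of the original module: exercise the
    # synchronous API against a throwaway in-memory SQLite database (requires
    # the SQLAlchemy 2 code path; see the note in create_schema).
    demo_store = SQLStore(namespace="demo", db_url="sqlite:///:memory:")
    demo_store.create_schema()
    demo_store.mset([("a", b"1"), ("b", b"2")])
    print(demo_store.mget(["a", "b", "missing"]))  # [b'1', b'2', None]
    demo_store.mdelete(["a"])
    print(list(demo_store.yield_keys()))  # ['b']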