mirror of
https://github.com/hwchase17/langchain.git
synced 2025-07-04 12:18:24 +00:00
community: revert SQL Stores (#16912)
This reverts commit cfc225ecb3
.
https://github.com/langchain-ai/langchain/pull/15909#issuecomment-1922418097
These will have existed in langchain-community 0.0.16 and 0.0.17.
This commit is contained in:
parent
f7c709b40e
commit
b1a847366c
@ -1,186 +0,0 @@
|
|||||||
{
|
|
||||||
"cells": [
|
|
||||||
{
|
|
||||||
"cell_type": "raw",
|
|
||||||
"metadata": {},
|
|
||||||
"source": [
|
|
||||||
"---\n",
|
|
||||||
"sidebar_label: SQL\n",
|
|
||||||
"---"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"cell_type": "markdown",
|
|
||||||
"metadata": {},
|
|
||||||
"source": [
|
|
||||||
"# SQLStore\n",
|
|
||||||
"\n",
|
|
||||||
"The `SQLStrStore` and `SQLDocStore` implement remote data access and persistence to store strings or LangChain documents in your SQL instance."
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"cell_type": "code",
|
|
||||||
"execution_count": 5,
|
|
||||||
"metadata": {},
|
|
||||||
"outputs": [
|
|
||||||
{
|
|
||||||
"name": "stdout",
|
|
||||||
"output_type": "stream",
|
|
||||||
"text": [
|
|
||||||
"['value1', 'value2']\n",
|
|
||||||
"['key2']\n",
|
|
||||||
"['key2']\n"
|
|
||||||
]
|
|
||||||
}
|
|
||||||
],
|
|
||||||
"source": [
|
|
||||||
"from langchain_community.storage import SQLStrStore\n",
|
|
||||||
"\n",
|
|
||||||
"# simple example using an SQLStrStore to store strings\n",
|
|
||||||
"# same as you would use in \"InMemoryStore\" but using SQL persistence\n",
|
|
||||||
"CONNECTION_STRING = \"postgresql+psycopg2://user:pass@localhost:5432/db\"\n",
|
|
||||||
"COLLECTION_NAME = \"test_collection\"\n",
|
|
||||||
"\n",
|
|
||||||
"store = SQLStrStore(\n",
|
|
||||||
" collection_name=COLLECTION_NAME,\n",
|
|
||||||
" connection_string=CONNECTION_STRING,\n",
|
|
||||||
")\n",
|
|
||||||
"store.mset([(\"key1\", \"value1\"), (\"key2\", \"value2\")])\n",
|
|
||||||
"print(store.mget([\"key1\", \"key2\"]))\n",
|
|
||||||
"# ['value1', 'value2']\n",
|
|
||||||
"store.mdelete([\"key1\"])\n",
|
|
||||||
"print(list(store.yield_keys()))\n",
|
|
||||||
"# ['key2']\n",
|
|
||||||
"print(list(store.yield_keys(prefix=\"k\")))\n",
|
|
||||||
"# ['key2']\n",
|
|
||||||
"# delete the COLLECTION_NAME collection"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"cell_type": "markdown",
|
|
||||||
"metadata": {},
|
|
||||||
"source": [
|
|
||||||
"## Integration with ParentRetriever and PGVector\n",
|
|
||||||
"\n",
|
|
||||||
"When using PGVector, you already have a SQL instance running. Here is a convenient way of using this instance to store documents associated to vectors. "
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"cell_type": "markdown",
|
|
||||||
"metadata": {},
|
|
||||||
"source": [
|
|
||||||
"Prepare the PGVector vectorestore with something like this:"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"cell_type": "code",
|
|
||||||
"execution_count": 2,
|
|
||||||
"metadata": {},
|
|
||||||
"outputs": [],
|
|
||||||
"source": [
|
|
||||||
"from langchain_community.vectorstores import PGVector\n",
|
|
||||||
"from langchain_openai import OpenAIEmbeddings"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"cell_type": "code",
|
|
||||||
"execution_count": null,
|
|
||||||
"metadata": {},
|
|
||||||
"outputs": [],
|
|
||||||
"source": [
|
|
||||||
"embeddings = OpenAIEmbeddings()\n",
|
|
||||||
"vector_db = PGVector.from_existing_index(\n",
|
|
||||||
" embedding=embeddings,\n",
|
|
||||||
" collection_name=COLLECTION_NAME,\n",
|
|
||||||
" connection_string=CONNECTION_STRING,\n",
|
|
||||||
")"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"cell_type": "markdown",
|
|
||||||
"metadata": {},
|
|
||||||
"source": [
|
|
||||||
"Then create the parent retiever using `SQLDocStore` to persist the documents"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"cell_type": "code",
|
|
||||||
"execution_count": null,
|
|
||||||
"metadata": {},
|
|
||||||
"outputs": [],
|
|
||||||
"source": [
|
|
||||||
"from langchain.document_loaders import TextLoader\n",
|
|
||||||
"from langchain.retrievers import ParentDocumentRetriever\n",
|
|
||||||
"from langchain.text_splitter import RecursiveCharacterTextSplitter\n",
|
|
||||||
"from langchain_community.storage import SQLDocStore\n",
|
|
||||||
"\n",
|
|
||||||
"CONNECTION_STRING = \"postgresql+psycopg2://user:pass@localhost:5432/db\"\n",
|
|
||||||
"COLLECTION_NAME = \"state_of_the_union_test\"\n",
|
|
||||||
"docstore = SQLDocStore(\n",
|
|
||||||
" collection_name=COLLECTION_NAME,\n",
|
|
||||||
" connection_string=CONNECTION_STRING,\n",
|
|
||||||
")\n",
|
|
||||||
"\n",
|
|
||||||
"loader = TextLoader(\"./state_of_the_union.txt\")\n",
|
|
||||||
"documents = loader.load()\n",
|
|
||||||
"\n",
|
|
||||||
"parent_splitter = RecursiveCharacterTextSplitter(chunk_size=400)\n",
|
|
||||||
"child_splitter = RecursiveCharacterTextSplitter(chunk_size=50)\n",
|
|
||||||
"\n",
|
|
||||||
"retriever = ParentDocumentRetriever(\n",
|
|
||||||
" vectorstore=vector_db,\n",
|
|
||||||
" docstore=docstore,\n",
|
|
||||||
" child_splitter=child_splitter,\n",
|
|
||||||
" parent_splitter=parent_splitter,\n",
|
|
||||||
")\n",
|
|
||||||
"retriever.add_documents(documents)"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"cell_type": "markdown",
|
|
||||||
"metadata": {},
|
|
||||||
"source": [
|
|
||||||
"## Delete a collection"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"cell_type": "code",
|
|
||||||
"execution_count": null,
|
|
||||||
"metadata": {},
|
|
||||||
"outputs": [],
|
|
||||||
"source": [
|
|
||||||
"from langchain_community.storage import SQLStrStore\n",
|
|
||||||
"\n",
|
|
||||||
"# delete the COLLECTION_NAME collection\n",
|
|
||||||
"CONNECTION_STRING = \"postgresql+psycopg2://user:pass@localhost:5432/db\"\n",
|
|
||||||
"COLLECTION_NAME = \"test_collection\"\n",
|
|
||||||
"store = SQLStrStore(\n",
|
|
||||||
" collection_name=COLLECTION_NAME,\n",
|
|
||||||
" connection_string=CONNECTION_STRING,\n",
|
|
||||||
")\n",
|
|
||||||
"store.delete_collection()"
|
|
||||||
]
|
|
||||||
}
|
|
||||||
],
|
|
||||||
"metadata": {
|
|
||||||
"kernelspec": {
|
|
||||||
"display_name": "Python 3 (ipykernel)",
|
|
||||||
"language": "python",
|
|
||||||
"name": "python3"
|
|
||||||
},
|
|
||||||
"language_info": {
|
|
||||||
"codemirror_mode": {
|
|
||||||
"name": "ipython",
|
|
||||||
"version": 3
|
|
||||||
},
|
|
||||||
"file_extension": ".py",
|
|
||||||
"mimetype": "text/x-python",
|
|
||||||
"name": "python",
|
|
||||||
"nbconvert_exporter": "python",
|
|
||||||
"pygments_lexer": "ipython3",
|
|
||||||
"version": "3.10.1"
|
|
||||||
}
|
|
||||||
},
|
|
||||||
"nbformat": 4,
|
|
||||||
"nbformat_minor": 2
|
|
||||||
}
|
|
@ -11,10 +11,6 @@ from langchain_community.storage.astradb import (
|
|||||||
AstraDBStore,
|
AstraDBStore,
|
||||||
)
|
)
|
||||||
from langchain_community.storage.redis import RedisStore
|
from langchain_community.storage.redis import RedisStore
|
||||||
from langchain_community.storage.sql import (
|
|
||||||
SQLDocStore,
|
|
||||||
SQLStrStore,
|
|
||||||
)
|
|
||||||
from langchain_community.storage.upstash_redis import (
|
from langchain_community.storage.upstash_redis import (
|
||||||
UpstashRedisByteStore,
|
UpstashRedisByteStore,
|
||||||
UpstashRedisStore,
|
UpstashRedisStore,
|
||||||
@ -26,6 +22,4 @@ __all__ = [
|
|||||||
"RedisStore",
|
"RedisStore",
|
||||||
"UpstashRedisByteStore",
|
"UpstashRedisByteStore",
|
||||||
"UpstashRedisStore",
|
"UpstashRedisStore",
|
||||||
"SQLDocStore",
|
|
||||||
"SQLStrStore",
|
|
||||||
]
|
]
|
||||||
|
@ -1,345 +0,0 @@
|
|||||||
"""SQL storage that persists data in a SQL database
|
|
||||||
and supports data isolation using collections."""
|
|
||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
import uuid
|
|
||||||
from typing import Any, Generic, Iterator, List, Optional, Sequence, Tuple, TypeVar
|
|
||||||
|
|
||||||
import sqlalchemy
|
|
||||||
from sqlalchemy import JSON, UUID
|
|
||||||
from sqlalchemy.orm import Session, relationship
|
|
||||||
|
|
||||||
try:
|
|
||||||
from sqlalchemy.orm import declarative_base
|
|
||||||
except ImportError:
|
|
||||||
from sqlalchemy.ext.declarative import declarative_base
|
|
||||||
|
|
||||||
from langchain_core.documents import Document
|
|
||||||
from langchain_core.load import Serializable, dumps, loads
|
|
||||||
from langchain_core.stores import BaseStore
|
|
||||||
|
|
||||||
V = TypeVar("V")
|
|
||||||
|
|
||||||
ITERATOR_WINDOW_SIZE = 1000
|
|
||||||
|
|
||||||
Base = declarative_base() # type: Any
|
|
||||||
|
|
||||||
|
|
||||||
_LANGCHAIN_DEFAULT_COLLECTION_NAME = "langchain"
|
|
||||||
|
|
||||||
|
|
||||||
class BaseModel(Base):
|
|
||||||
"""Base model for the SQL stores."""
|
|
||||||
|
|
||||||
__abstract__ = True
|
|
||||||
uuid = sqlalchemy.Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
|
|
||||||
|
|
||||||
|
|
||||||
_classes: Any = None
|
|
||||||
|
|
||||||
|
|
||||||
def _get_storage_stores() -> Any:
|
|
||||||
global _classes
|
|
||||||
if _classes is not None:
|
|
||||||
return _classes
|
|
||||||
|
|
||||||
class CollectionStore(BaseModel):
|
|
||||||
"""Collection store."""
|
|
||||||
|
|
||||||
__tablename__ = "langchain_storage_collection"
|
|
||||||
|
|
||||||
name = sqlalchemy.Column(sqlalchemy.String)
|
|
||||||
cmetadata = sqlalchemy.Column(JSON)
|
|
||||||
|
|
||||||
items = relationship(
|
|
||||||
"ItemStore",
|
|
||||||
back_populates="collection",
|
|
||||||
passive_deletes=True,
|
|
||||||
)
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def get_by_name(
|
|
||||||
cls, session: Session, name: str
|
|
||||||
) -> Optional["CollectionStore"]:
|
|
||||||
# type: ignore
|
|
||||||
return session.query(cls).filter(cls.name == name).first()
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def get_or_create(
|
|
||||||
cls,
|
|
||||||
session: Session,
|
|
||||||
name: str,
|
|
||||||
cmetadata: Optional[dict] = None,
|
|
||||||
) -> Tuple["CollectionStore", bool]:
|
|
||||||
"""
|
|
||||||
Get or create a collection.
|
|
||||||
Returns [Collection, bool] where the bool is True if the collection was created.
|
|
||||||
""" # noqa: E501
|
|
||||||
created = False
|
|
||||||
collection = cls.get_by_name(session, name)
|
|
||||||
if collection:
|
|
||||||
return collection, created
|
|
||||||
|
|
||||||
collection = cls(name=name, cmetadata=cmetadata)
|
|
||||||
session.add(collection)
|
|
||||||
session.commit()
|
|
||||||
created = True
|
|
||||||
return collection, created
|
|
||||||
|
|
||||||
class ItemStore(BaseModel):
|
|
||||||
"""Item store."""
|
|
||||||
|
|
||||||
__tablename__ = "langchain_storage_items"
|
|
||||||
|
|
||||||
collection_id = sqlalchemy.Column(
|
|
||||||
UUID(as_uuid=True),
|
|
||||||
sqlalchemy.ForeignKey(
|
|
||||||
f"{CollectionStore.__tablename__}.uuid",
|
|
||||||
ondelete="CASCADE",
|
|
||||||
),
|
|
||||||
)
|
|
||||||
collection = relationship(CollectionStore, back_populates="items")
|
|
||||||
|
|
||||||
content = sqlalchemy.Column(sqlalchemy.String, nullable=True)
|
|
||||||
|
|
||||||
# custom_id : any user defined id
|
|
||||||
custom_id = sqlalchemy.Column(sqlalchemy.String, nullable=True)
|
|
||||||
|
|
||||||
_classes = (ItemStore, CollectionStore)
|
|
||||||
|
|
||||||
return _classes
|
|
||||||
|
|
||||||
|
|
||||||
class SQLBaseStore(BaseStore[str, V], Generic[V]):
|
|
||||||
"""SQL storage
|
|
||||||
|
|
||||||
Args:
|
|
||||||
connection_string: SQL connection string that will be passed to SQLAlchemy.
|
|
||||||
collection_name: The name of the collection to use. (default: langchain)
|
|
||||||
NOTE: Collections are useful to isolate your data in a given a database.
|
|
||||||
This is not the name of the table, but the name of the collection.
|
|
||||||
The tables will be created when initializing the store (if not exists)
|
|
||||||
So, make sure the user has the right permissions to create tables.
|
|
||||||
pre_delete_collection: If True, will delete the collection if it exists.
|
|
||||||
(default: False). Useful for testing.
|
|
||||||
engine_args: SQLAlchemy's create engine arguments.
|
|
||||||
|
|
||||||
Example:
|
|
||||||
.. code-block:: python
|
|
||||||
|
|
||||||
from langchain_community.storage import SQLDocStore
|
|
||||||
from langchain_community.embeddings.openai import OpenAIEmbeddings
|
|
||||||
|
|
||||||
# example using an SQLDocStore to store Document objects for
|
|
||||||
# a ParentDocumentRetriever
|
|
||||||
CONNECTION_STRING = "postgresql+psycopg2://user:pass@localhost:5432/db"
|
|
||||||
COLLECTION_NAME = "state_of_the_union_test"
|
|
||||||
docstore = SQLDocStore(
|
|
||||||
collection_name=COLLECTION_NAME,
|
|
||||||
connection_string=CONNECTION_STRING,
|
|
||||||
)
|
|
||||||
child_splitter = RecursiveCharacterTextSplitter(chunk_size=400)
|
|
||||||
vectorstore = ...
|
|
||||||
|
|
||||||
retriever = ParentDocumentRetriever(
|
|
||||||
vectorstore=vectorstore,
|
|
||||||
docstore=docstore,
|
|
||||||
child_splitter=child_splitter,
|
|
||||||
)
|
|
||||||
|
|
||||||
# example using an SQLStrStore to store strings
|
|
||||||
# same example as in "InMemoryStore" but using SQL persistence
|
|
||||||
store = SQLDocStore(
|
|
||||||
collection_name=COLLECTION_NAME,
|
|
||||||
connection_string=CONNECTION_STRING,
|
|
||||||
)
|
|
||||||
store.mset([('key1', 'value1'), ('key2', 'value2')])
|
|
||||||
store.mget(['key1', 'key2'])
|
|
||||||
# ['value1', 'value2']
|
|
||||||
store.mdelete(['key1'])
|
|
||||||
list(store.yield_keys())
|
|
||||||
# ['key2']
|
|
||||||
list(store.yield_keys(prefix='k'))
|
|
||||||
# ['key2']
|
|
||||||
|
|
||||||
# delete the COLLECTION_NAME collection
|
|
||||||
docstore.delete_collection()
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
connection_string: str,
|
|
||||||
collection_name: str = _LANGCHAIN_DEFAULT_COLLECTION_NAME,
|
|
||||||
collection_metadata: Optional[dict] = None,
|
|
||||||
pre_delete_collection: bool = False,
|
|
||||||
connection: Optional[sqlalchemy.engine.Connection] = None,
|
|
||||||
engine_args: Optional[dict[str, Any]] = None,
|
|
||||||
) -> None:
|
|
||||||
self.connection_string = connection_string
|
|
||||||
self.collection_name = collection_name
|
|
||||||
self.collection_metadata = collection_metadata
|
|
||||||
self.pre_delete_collection = pre_delete_collection
|
|
||||||
self.engine_args = engine_args or {}
|
|
||||||
# Create a connection if not provided, otherwise use the provided connection
|
|
||||||
self._conn = connection if connection else self.__connect()
|
|
||||||
self.__post_init__()
|
|
||||||
|
|
||||||
def __post_init__(
|
|
||||||
self,
|
|
||||||
) -> None:
|
|
||||||
"""Initialize the store."""
|
|
||||||
ItemStore, CollectionStore = _get_storage_stores()
|
|
||||||
self.CollectionStore = CollectionStore
|
|
||||||
self.ItemStore = ItemStore
|
|
||||||
self.__create_tables_if_not_exists()
|
|
||||||
self.__create_collection()
|
|
||||||
|
|
||||||
def __connect(self) -> sqlalchemy.engine.Connection:
|
|
||||||
engine = sqlalchemy.create_engine(self.connection_string, **self.engine_args)
|
|
||||||
conn = engine.connect()
|
|
||||||
return conn
|
|
||||||
|
|
||||||
def __create_tables_if_not_exists(self) -> None:
|
|
||||||
with self._conn.begin():
|
|
||||||
Base.metadata.create_all(self._conn)
|
|
||||||
|
|
||||||
def __create_collection(self) -> None:
|
|
||||||
if self.pre_delete_collection:
|
|
||||||
self.delete_collection()
|
|
||||||
with Session(self._conn) as session:
|
|
||||||
self.CollectionStore.get_or_create(
|
|
||||||
session, self.collection_name, cmetadata=self.collection_metadata
|
|
||||||
)
|
|
||||||
|
|
||||||
def delete_collection(self) -> None:
|
|
||||||
with Session(self._conn) as session:
|
|
||||||
collection = self.__get_collection(session)
|
|
||||||
if not collection:
|
|
||||||
return
|
|
||||||
session.delete(collection)
|
|
||||||
session.commit()
|
|
||||||
|
|
||||||
def __get_collection(self, session: Session) -> Any:
|
|
||||||
return self.CollectionStore.get_by_name(session, self.collection_name)
|
|
||||||
|
|
||||||
def __del__(self) -> None:
|
|
||||||
if self._conn:
|
|
||||||
self._conn.close()
|
|
||||||
|
|
||||||
def __serialize_value(self, obj: V) -> str:
|
|
||||||
if isinstance(obj, Serializable):
|
|
||||||
return dumps(obj)
|
|
||||||
return obj
|
|
||||||
|
|
||||||
def __deserialize_value(self, obj: V) -> str:
|
|
||||||
try:
|
|
||||||
return loads(obj)
|
|
||||||
except Exception:
|
|
||||||
return obj
|
|
||||||
|
|
||||||
def mget(self, keys: Sequence[str]) -> List[Optional[V]]:
|
|
||||||
"""Get the values associated with the given keys.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
keys (Sequence[str]): A sequence of keys.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
A sequence of optional values associated with the keys.
|
|
||||||
If a key is not found, the corresponding value will be None.
|
|
||||||
"""
|
|
||||||
with Session(self._conn) as session:
|
|
||||||
collection = self.__get_collection(session)
|
|
||||||
|
|
||||||
items = (
|
|
||||||
session.query(self.ItemStore.content, self.ItemStore.custom_id)
|
|
||||||
.where(
|
|
||||||
sqlalchemy.and_(
|
|
||||||
self.ItemStore.custom_id.in_(keys),
|
|
||||||
self.ItemStore.collection_id == (collection.uuid),
|
|
||||||
)
|
|
||||||
)
|
|
||||||
.all()
|
|
||||||
)
|
|
||||||
|
|
||||||
ordered_values = {key: None for key in keys}
|
|
||||||
for item in items:
|
|
||||||
v = item[0]
|
|
||||||
val = self.__deserialize_value(v) if v is not None else v
|
|
||||||
k = item[1]
|
|
||||||
ordered_values[k] = val
|
|
||||||
|
|
||||||
return [ordered_values[key] for key in keys]
|
|
||||||
|
|
||||||
def mset(self, key_value_pairs: Sequence[Tuple[str, V]]) -> None:
|
|
||||||
"""Set the values for the given keys.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
key_value_pairs (Sequence[Tuple[str, V]]): A sequence of key-value pairs.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
None
|
|
||||||
"""
|
|
||||||
with Session(self._conn) as session:
|
|
||||||
collection = self.__get_collection(session)
|
|
||||||
if not collection:
|
|
||||||
raise ValueError("Collection not found")
|
|
||||||
for id, item in key_value_pairs:
|
|
||||||
content = self.__serialize_value(item)
|
|
||||||
item_store = self.ItemStore(
|
|
||||||
content=content,
|
|
||||||
custom_id=id,
|
|
||||||
collection_id=collection.uuid,
|
|
||||||
)
|
|
||||||
session.add(item_store)
|
|
||||||
session.commit()
|
|
||||||
|
|
||||||
def mdelete(self, keys: Sequence[str]) -> None:
|
|
||||||
"""Delete the given keys and their associated values.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
keys (Sequence[str]): A sequence of keys to delete.
|
|
||||||
"""
|
|
||||||
with Session(self._conn) as session:
|
|
||||||
collection = self.__get_collection(session)
|
|
||||||
if not collection:
|
|
||||||
raise ValueError("Collection not found")
|
|
||||||
if keys is not None:
|
|
||||||
stmt = sqlalchemy.delete(self.ItemStore).where(
|
|
||||||
sqlalchemy.and_(
|
|
||||||
self.ItemStore.custom_id.in_(keys),
|
|
||||||
self.ItemStore.collection_id == (collection.uuid),
|
|
||||||
)
|
|
||||||
)
|
|
||||||
session.execute(stmt)
|
|
||||||
session.commit()
|
|
||||||
|
|
||||||
def yield_keys(self, prefix: Optional[str] = None) -> Iterator[str]:
|
|
||||||
"""Get an iterator over keys that match the given prefix.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
prefix (str, optional): The prefix to match. Defaults to None.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Iterator[str]: An iterator over keys that match the given prefix.
|
|
||||||
"""
|
|
||||||
with Session(self._conn) as session:
|
|
||||||
collection = self.__get_collection(session)
|
|
||||||
start = 0
|
|
||||||
while True:
|
|
||||||
stop = start + ITERATOR_WINDOW_SIZE
|
|
||||||
query = session.query(self.ItemStore.custom_id).where(
|
|
||||||
self.ItemStore.collection_id == (collection.uuid)
|
|
||||||
)
|
|
||||||
if prefix is not None:
|
|
||||||
query = query.filter(self.ItemStore.custom_id.startswith(prefix))
|
|
||||||
items = query.slice(start, stop).all()
|
|
||||||
|
|
||||||
if len(items) == 0:
|
|
||||||
break
|
|
||||||
for item in items:
|
|
||||||
yield item[0]
|
|
||||||
start += ITERATOR_WINDOW_SIZE
|
|
||||||
|
|
||||||
|
|
||||||
SQLDocStore = SQLBaseStore[Document]
|
|
||||||
SQLStrStore = SQLBaseStore[str]
|
|
@ -1,228 +0,0 @@
|
|||||||
"""Implement integration tests for SQL storage."""
|
|
||||||
import os
|
|
||||||
|
|
||||||
from langchain_core.documents import Document
|
|
||||||
|
|
||||||
from langchain_community.storage.sql import SQLDocStore, SQLStrStore
|
|
||||||
|
|
||||||
|
|
||||||
def connection_string_from_db_params() -> str:
|
|
||||||
"""Return connection string from database parameters."""
|
|
||||||
dbdriver = os.environ.get("TEST_SQL_DBDRIVER", "postgresql+psycopg2")
|
|
||||||
host = os.environ.get("TEST_SQL_HOST", "localhost")
|
|
||||||
port = int(os.environ.get("TEST_SQL_PORT", "5432"))
|
|
||||||
database = os.environ.get("TEST_SQL_DATABASE", "postgres")
|
|
||||||
user = os.environ.get("TEST_SQL_USER", "postgres")
|
|
||||||
password = os.environ.get("TEST_SQL_PASSWORD", "postgres")
|
|
||||||
return f"{dbdriver}://{user}:{password}@{host}:{port}/{database}"
|
|
||||||
|
|
||||||
|
|
||||||
CONNECTION_STRING = connection_string_from_db_params()
|
|
||||||
COLLECTION_NAME = "test_collection"
|
|
||||||
COLLECTION_NAME_2 = "test_collection_2"
|
|
||||||
|
|
||||||
|
|
||||||
def test_str_store_mget() -> None:
|
|
||||||
store = SQLStrStore(
|
|
||||||
collection_name=COLLECTION_NAME,
|
|
||||||
connection_string=CONNECTION_STRING,
|
|
||||||
pre_delete_collection=True,
|
|
||||||
)
|
|
||||||
store.mset([("key1", "value1"), ("key2", "value2")])
|
|
||||||
|
|
||||||
values = store.mget(["key1", "key2"])
|
|
||||||
assert values == ["value1", "value2"]
|
|
||||||
|
|
||||||
# Test non-existent key
|
|
||||||
non_existent_value = store.mget(["key3"])
|
|
||||||
assert non_existent_value == [None]
|
|
||||||
store.delete_collection()
|
|
||||||
|
|
||||||
|
|
||||||
def test_str_store_mset() -> None:
|
|
||||||
store = SQLStrStore(
|
|
||||||
collection_name=COLLECTION_NAME,
|
|
||||||
connection_string=CONNECTION_STRING,
|
|
||||||
pre_delete_collection=True,
|
|
||||||
)
|
|
||||||
store.mset([("key1", "value1"), ("key2", "value2")])
|
|
||||||
|
|
||||||
values = store.mget(["key1", "key2"])
|
|
||||||
assert values == ["value1", "value2"]
|
|
||||||
store.delete_collection()
|
|
||||||
|
|
||||||
|
|
||||||
def test_str_store_mdelete() -> None:
|
|
||||||
store = SQLStrStore(
|
|
||||||
collection_name=COLLECTION_NAME,
|
|
||||||
connection_string=CONNECTION_STRING,
|
|
||||||
pre_delete_collection=True,
|
|
||||||
)
|
|
||||||
store.mset([("key1", "value1"), ("key2", "value2")])
|
|
||||||
|
|
||||||
store.mdelete(["key1"])
|
|
||||||
|
|
||||||
values = store.mget(["key1", "key2"])
|
|
||||||
assert values == [None, "value2"]
|
|
||||||
|
|
||||||
# Test deleting non-existent key
|
|
||||||
store.mdelete(["key3"]) # No error should be raised
|
|
||||||
store.delete_collection()
|
|
||||||
|
|
||||||
|
|
||||||
def test_str_store_yield_keys() -> None:
|
|
||||||
store = SQLStrStore(
|
|
||||||
collection_name=COLLECTION_NAME,
|
|
||||||
connection_string=CONNECTION_STRING,
|
|
||||||
pre_delete_collection=True,
|
|
||||||
)
|
|
||||||
store.mset([("key1", "value1"), ("key2", "value2"), ("key3", "value3")])
|
|
||||||
|
|
||||||
keys = list(store.yield_keys())
|
|
||||||
assert set(keys) == {"key1", "key2", "key3"}
|
|
||||||
|
|
||||||
keys_with_prefix = list(store.yield_keys(prefix="key"))
|
|
||||||
assert set(keys_with_prefix) == {"key1", "key2", "key3"}
|
|
||||||
|
|
||||||
keys_with_invalid_prefix = list(store.yield_keys(prefix="x"))
|
|
||||||
assert keys_with_invalid_prefix == []
|
|
||||||
store.delete_collection()
|
|
||||||
|
|
||||||
|
|
||||||
def test_str_store_collection() -> None:
|
|
||||||
"""Test that collections are isolated within a db."""
|
|
||||||
store_1 = SQLStrStore(
|
|
||||||
collection_name=COLLECTION_NAME,
|
|
||||||
connection_string=CONNECTION_STRING,
|
|
||||||
pre_delete_collection=True,
|
|
||||||
)
|
|
||||||
store_2 = SQLStrStore(
|
|
||||||
collection_name=COLLECTION_NAME_2,
|
|
||||||
connection_string=CONNECTION_STRING,
|
|
||||||
pre_delete_collection=True,
|
|
||||||
)
|
|
||||||
|
|
||||||
store_1.mset([("key1", "value1"), ("key2", "value2")])
|
|
||||||
store_2.mset([("key3", "value3"), ("key4", "value4")])
|
|
||||||
|
|
||||||
values = store_1.mget(["key1", "key2"])
|
|
||||||
assert values == ["value1", "value2"]
|
|
||||||
values = store_1.mget(["key3", "key4"])
|
|
||||||
assert values == [None, None]
|
|
||||||
|
|
||||||
values = store_2.mget(["key1", "key2"])
|
|
||||||
assert values == [None, None]
|
|
||||||
values = store_2.mget(["key3", "key4"])
|
|
||||||
assert values == ["value3", "value4"]
|
|
||||||
|
|
||||||
store_1.delete_collection()
|
|
||||||
store_2.delete_collection()
|
|
||||||
|
|
||||||
|
|
||||||
def test_doc_store_mget() -> None:
|
|
||||||
store = SQLDocStore(
|
|
||||||
collection_name=COLLECTION_NAME,
|
|
||||||
connection_string=CONNECTION_STRING,
|
|
||||||
pre_delete_collection=True,
|
|
||||||
)
|
|
||||||
doc_1 = Document(page_content="value1")
|
|
||||||
doc_2 = Document(page_content="value2")
|
|
||||||
store.mset([("key1", doc_1), ("key2", doc_2)])
|
|
||||||
|
|
||||||
values = store.mget(["key1", "key2"])
|
|
||||||
assert values == [doc_1, doc_2]
|
|
||||||
|
|
||||||
# Test non-existent key
|
|
||||||
non_existent_value = store.mget(["key3"])
|
|
||||||
assert non_existent_value == [None]
|
|
||||||
store.delete_collection()
|
|
||||||
|
|
||||||
|
|
||||||
def test_doc_store_mset() -> None:
|
|
||||||
store = SQLDocStore(
|
|
||||||
collection_name=COLLECTION_NAME,
|
|
||||||
connection_string=CONNECTION_STRING,
|
|
||||||
pre_delete_collection=True,
|
|
||||||
)
|
|
||||||
doc_1 = Document(page_content="value1")
|
|
||||||
doc_2 = Document(page_content="value2")
|
|
||||||
store.mset([("key1", doc_1), ("key2", doc_2)])
|
|
||||||
|
|
||||||
values = store.mget(["key1", "key2"])
|
|
||||||
assert values == [doc_1, doc_2]
|
|
||||||
store.delete_collection()
|
|
||||||
|
|
||||||
|
|
||||||
def test_doc_store_mdelete() -> None:
|
|
||||||
store = SQLDocStore(
|
|
||||||
collection_name=COLLECTION_NAME,
|
|
||||||
connection_string=CONNECTION_STRING,
|
|
||||||
pre_delete_collection=True,
|
|
||||||
)
|
|
||||||
doc_1 = Document(page_content="value1")
|
|
||||||
doc_2 = Document(page_content="value2")
|
|
||||||
store.mset([("key1", doc_1), ("key2", doc_2)])
|
|
||||||
|
|
||||||
store.mdelete(["key1"])
|
|
||||||
|
|
||||||
values = store.mget(["key1", "key2"])
|
|
||||||
assert values == [None, doc_2]
|
|
||||||
|
|
||||||
# Test deleting non-existent key
|
|
||||||
store.mdelete(["key3"]) # No error should be raised
|
|
||||||
store.delete_collection()
|
|
||||||
|
|
||||||
|
|
||||||
def test_doc_store_yield_keys() -> None:
|
|
||||||
store = SQLDocStore(
|
|
||||||
collection_name=COLLECTION_NAME,
|
|
||||||
connection_string=CONNECTION_STRING,
|
|
||||||
pre_delete_collection=True,
|
|
||||||
)
|
|
||||||
doc_1 = Document(page_content="value1")
|
|
||||||
doc_2 = Document(page_content="value2")
|
|
||||||
doc_3 = Document(page_content="value3")
|
|
||||||
store.mset([("key1", doc_1), ("key2", doc_2), ("key3", doc_3)])
|
|
||||||
|
|
||||||
keys = list(store.yield_keys())
|
|
||||||
assert set(keys) == {"key1", "key2", "key3"}
|
|
||||||
|
|
||||||
keys_with_prefix = list(store.yield_keys(prefix="key"))
|
|
||||||
assert set(keys_with_prefix) == {"key1", "key2", "key3"}
|
|
||||||
|
|
||||||
keys_with_invalid_prefix = list(store.yield_keys(prefix="x"))
|
|
||||||
assert keys_with_invalid_prefix == []
|
|
||||||
store.delete_collection()
|
|
||||||
|
|
||||||
|
|
||||||
def test_doc_store_collection() -> None:
|
|
||||||
"""Test that collections are isolated within a db."""
|
|
||||||
store_1 = SQLDocStore(
|
|
||||||
collection_name=COLLECTION_NAME,
|
|
||||||
connection_string=CONNECTION_STRING,
|
|
||||||
pre_delete_collection=True,
|
|
||||||
)
|
|
||||||
store_2 = SQLDocStore(
|
|
||||||
collection_name=COLLECTION_NAME_2,
|
|
||||||
connection_string=CONNECTION_STRING,
|
|
||||||
pre_delete_collection=True,
|
|
||||||
)
|
|
||||||
doc_1 = Document(page_content="value1")
|
|
||||||
doc_2 = Document(page_content="value2")
|
|
||||||
doc_3 = Document(page_content="value3")
|
|
||||||
doc_4 = Document(page_content="value4")
|
|
||||||
store_1.mset([("key1", doc_1), ("key2", doc_2)])
|
|
||||||
store_2.mset([("key3", doc_3), ("key4", doc_4)])
|
|
||||||
|
|
||||||
values = store_1.mget(["key1", "key2"])
|
|
||||||
assert values == [doc_1, doc_2]
|
|
||||||
values = store_1.mget(["key3", "key4"])
|
|
||||||
assert values == [None, None]
|
|
||||||
|
|
||||||
values = store_2.mget(["key1", "key2"])
|
|
||||||
assert values == [None, None]
|
|
||||||
values = store_2.mget(["key3", "key4"])
|
|
||||||
assert values == [doc_3, doc_4]
|
|
||||||
|
|
||||||
store_1.delete_collection()
|
|
||||||
store_2.delete_collection()
|
|
@ -6,8 +6,6 @@ EXPECTED_ALL = [
|
|||||||
"RedisStore",
|
"RedisStore",
|
||||||
"UpstashRedisByteStore",
|
"UpstashRedisByteStore",
|
||||||
"UpstashRedisStore",
|
"UpstashRedisStore",
|
||||||
"SQLDocStore",
|
|
||||||
"SQLStrStore",
|
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
|
@ -1,7 +0,0 @@
|
|||||||
"""Light weight unit test that attempts to import SQLDocStore/SQLStrStore.
|
|
||||||
"""
|
|
||||||
|
|
||||||
|
|
||||||
def test_import_storage() -> None:
|
|
||||||
"""Attempt to import storage modules."""
|
|
||||||
from langchain_community.storage.sql import SQLDocStore, SQLStrStore # noqa
|
|
Loading…
Reference in New Issue
Block a user