Add indexing support (#9614)

This PR introduces a persistence layer to help with indexing workflows into
vectorstores.

The indexing code helps users to:

1. Avoid writing duplicated content into the vectorstore
2. Avoid overwriting content if it's unchanged

Importantly, this keeps working even if the content being written is derived
via a set of transformations from some source content (e.g., indexing child
documents that were derived from parent documents by chunking).

The two main components are:

1. A persistence layer that keeps track of which keys were updated and when.
   Keeping track of the timestamp of updates allows old content to be cleaned
   up safely and with minimal complexity.
2. A HashedDocument which is used to hash the contents (including metadata) of
   the documents. We rely on the hashes to identify duplicates.


The indexing code works with **ANY** document loader. To add transformations
to the documents, users can for now write a custom document loader that
composes an existing loader with document transformers.
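
For illustration, here is a minimal usage sketch of the new `index()` API (not part of the diff). The in-memory vector store mirrors the test helper added in this PR; the namespace, database URL, and documents are arbitrary examples.

```python
from typing import Any, Dict, Iterable, List, Optional, Sequence

from langchain.indexes import SQLRecordManager, index
from langchain.schema import Document
from langchain.vectorstores.base import VectorStore


class InMemoryVectorStore(VectorStore):
    """Tiny in-memory store, mirroring the test helper in this PR."""

    def __init__(self) -> None:
        self.store: Dict[str, Document] = {}

    def delete(self, ids: Optional[Sequence[str]] = None, **kwargs: Any) -> None:
        for _id in ids or []:
            self.store.pop(_id, None)

    def add_documents(  # type: ignore[override]
        self,
        documents: Sequence[Document],
        *,
        ids: Optional[Sequence[str]] = None,
        **kwargs: Any,
    ) -> None:
        for _id, document in zip(ids or [], documents):
            self.store[_id] = document

    def add_texts(
        self, texts: Iterable[str], metadatas: Optional[List[dict]] = None, **kwargs: Any
    ) -> List[str]:
        raise NotImplementedError

    @classmethod
    def from_texts(cls, texts, embedding, metadatas=None, **kwargs):  # type: ignore[override]
        raise NotImplementedError

    def similarity_search(self, query: str, k: int = 4, **kwargs: Any) -> List[Document]:
        raise NotImplementedError


# The record manager tracks which document hashes were upserted and when.
record_manager = SQLRecordManager("example_namespace", db_url="sqlite:///:memory:")
record_manager.create_schema()

docs = [
    Document(page_content="hello", metadata={"source": "a.txt"}),
    Document(page_content="world", metadata={"source": "b.txt"}),
]

# First run adds both documents; re-running with the same record manager skips
# them because their hashes are already recorded, and "incremental" mode cleans
# up stale content that shares a source id with the documents being indexed.
result = index(
    docs,
    record_manager,
    InMemoryVectorStore(),
    delete_mode="incremental",
    source_id_key="source",
)
# result == {"num_added": 2, "num_updated": 0, "num_skipped": 0, "num_deleted": 0}
```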

---------

Co-authored-by: Bagatur <baskaryan@gmail.com>
Authored by Eugene Yurtsev on 2023-08-23 21:41:38 -04:00, committed by GitHub
parent c215481531
commit b88dfcb42a
9 changed files with 2460 additions and 2 deletions

View File

@@ -1,5 +1,28 @@
"""**Index** utilities."""
"""Code to support various indexing workflows.
Provides code to:
* Create knowledge graphs from data.
* Support indexing workflows from LangChain data loaders to vectorstores.
For indexing workflows, this code is used to avoid writing duplicated content
into the vectorstore and to avoid over-writing content if it's unchanged.
Importantly, this keeps working even if the content being written is derived
via a set of transformations from some source content (e.g., indexing child
documents that were derived from parent documents by chunking).
"""
from langchain.indexes._api import IndexingResult, index
from langchain.indexes._sql_record_manager import SQLRecordManager
from langchain.indexes.graph import GraphIndexCreator
from langchain.indexes.vectorstore import VectorstoreIndexCreator
__all__ = ["GraphIndexCreator", "VectorstoreIndexCreator"]
__all__ = [
# Keep sorted
"GraphIndexCreator",
"index",
"IndexingResult",
"SQLRecordManager",
"VectorstoreIndexCreator",
]

View File

@@ -0,0 +1,346 @@
"""Module contains logic for indexing documents into vector stores."""
from __future__ import annotations
import hashlib
import json
import uuid
from itertools import islice
from typing import (
Any,
Callable,
Dict,
Iterable,
Iterator,
List,
Literal,
Optional,
Sequence,
TypedDict,
TypeVar,
Union,
cast,
)
from langchain.document_loaders.base import BaseLoader
from langchain.indexes.base import NAMESPACE_UUID, RecordManager
from langchain.pydantic_v1 import root_validator
from langchain.schema import Document
from langchain.vectorstores.base import VectorStore
T = TypeVar("T")
def _hash_string_to_uuid(input_string: str) -> uuid.UUID:
"""Hashes a string and returns the corresponding UUID."""
hash_value = hashlib.sha1(input_string.encode("utf-8")).hexdigest()
return uuid.uuid5(NAMESPACE_UUID, hash_value)
def _hash_nested_dict_to_uuid(data: dict) -> uuid.UUID:
"""Hashes a nested dictionary and returns the corresponding UUID."""
serialized_data = json.dumps(data, sort_keys=True)
hash_value = hashlib.sha1(serialized_data.encode("utf-8")).hexdigest()
return uuid.uuid5(NAMESPACE_UUID, hash_value)
class _HashedDocument(Document):
"""A hashed document with a unique ID."""
uid: str
hash_: str
"""The hash of the document including content and metadata."""
content_hash: str
"""The hash of the document content."""
metadata_hash: str
"""The hash of the document metadata."""
@root_validator(pre=True)
def calculate_hashes(cls, values: Dict[str, Any]) -> Dict[str, Any]:
"""Root validator to calculate content and metadata hash."""
content = values.get("page_content", "")
metadata = values.get("metadata", {})
forbidden_keys = ("hash_", "content_hash", "metadata_hash")
for key in forbidden_keys:
if key in metadata:
raise ValueError(
f"Metadata cannot contain key {key} as it "
f"is reserved for internal use."
)
content_hash = str(_hash_string_to_uuid(content))
try:
metadata_hash = str(_hash_nested_dict_to_uuid(metadata))
except Exception as e:
raise ValueError(
f"Failed to hash metadata: {e}. "
f"Please use a dict that can be serialized using json."
)
values["content_hash"] = content_hash
values["metadata_hash"] = metadata_hash
values["hash_"] = str(_hash_string_to_uuid(content_hash + metadata_hash))
_uid = values.get("uid", None)
if _uid is None:
values["uid"] = values["hash_"]
return values
def to_document(self) -> Document:
"""Return a Document object."""
return Document(
page_content=self.page_content,
metadata=self.metadata,
)
@classmethod
def from_document(
cls, document: Document, *, uid: Optional[str] = None
) -> _HashedDocument:
"""Create a HashedDocument from a Document."""
return cls(
uid=uid,
page_content=document.page_content,
metadata=document.metadata,
)
def _batch(size: int, iterable: Iterable[T]) -> Iterator[List[T]]:
"""Utility batching function."""
it = iter(iterable)
while True:
chunk = list(islice(it, size))
if not chunk:
return
yield chunk
def _get_source_id_assigner(
source_id_key: Union[str, Callable[[Document], str], None],
) -> Callable[[Document], Union[str, None]]:
"""Get the source id from the document."""
if source_id_key is None:
return lambda doc: None
elif isinstance(source_id_key, str):
return lambda doc: doc.metadata[source_id_key]
elif callable(source_id_key):
return source_id_key
else:
raise ValueError(
f"source_id_key should be either None, a string or a callable. "
f"Got {source_id_key} of type {type(source_id_key)}."
)
def _deduplicate_in_order(
hashed_documents: Iterable[_HashedDocument],
) -> Iterator[_HashedDocument]:
"""Deduplicate a list of hashed documents while preserving order."""
seen = set()
for hashed_doc in hashed_documents:
if hashed_doc.hash_ not in seen:
seen.add(hashed_doc.hash_)
yield hashed_doc
# PUBLIC API
class IndexingResult(TypedDict):
"""Return a detailed a breakdown of the result of the indexing operation."""
num_added: int
"""Number of added documents."""
num_updated: int
"""Number of updated documents because they were not up to date."""
num_deleted: int
"""Number of deleted documents."""
num_skipped: int
"""Number of skipped documents because they were already up to date."""
def index(
docs_source: Union[BaseLoader, Iterable[Document]],
record_manager: RecordManager,
vector_store: VectorStore,
*,
batch_size: int = 100,
delete_mode: Literal["incremental", "full", None] = None,
source_id_key: Union[str, Callable[[Document], str], None] = None,
) -> IndexingResult:
"""Index data from the loader into the vector store.
Indexing functionality uses a manager to keep track of which documents
are in the vector store.
This allows us to keep track of which documents were updated, which
documents were deleted, and which documents should be skipped.
For the time being, documents are indexed using their hashes, and users
are not able to specify the uid of the document.
IMPORTANT:
if delete_mode is set to "full", the loader should be returning
the entire dataset, and not just a subset of the dataset.
Otherwise, the cleanup will remove documents that it is not
supposed to.
Args:
docs_source: Data loader or iterable of documents to index.
record_manager: Timestamped set to keep track of which documents were
updated.
vector_store: Vector store to index the documents into.
batch_size: Batch size to use when indexing.
delete_mode: How to handle clean up of documents.
- Incremental: Cleans up all documents that haven't been updated AND
that are associated with source ids that were seen
during indexing.
Clean up is done continuously during indexing, helping
to minimize the probability of users seeing duplicated
content.
- Full: Delete all documents that haven't been returned by the loader.
Clean up runs after all documents have been indexed.
This means that users may see duplicated content during indexing.
- None: Do not delete any documents.
source_id_key: Optional key that helps identify the original source
of the document.
Returns:
Indexing result which contains information about how many documents
were added, updated, deleted, or skipped.
"""
if delete_mode not in {"incremental", "full", None}:
raise ValueError(
f"delete_mode should be one of 'incremental', 'full' or None. "
f"Got {delete_mode}."
)
if delete_mode == "incremental" and source_id_key is None:
raise ValueError("Source id key is required when delete mode is incremental.")
# Check that the Vectorstore has required methods implemented
methods = ["delete", "add_documents"]
for method in methods:
if not hasattr(vector_store, method):
raise ValueError(
f"Vectorstore {vector_store} does not have required method {method}"
)
if type(vector_store).delete == VectorStore.delete:
# Checking if the vectorstore has overridden the default delete method
# implementation which just raises a NotImplementedError
raise ValueError("Vectorstore has not implemented the delete method")
if isinstance(docs_source, BaseLoader):
try:
doc_iterator = docs_source.lazy_load()
except NotImplementedError:
doc_iterator = iter(docs_source.load())
else:
doc_iterator = iter(docs_source)
source_id_assigner = _get_source_id_assigner(source_id_key)
# Mark when the update started.
index_start_dt = record_manager.get_time()
num_added = 0
num_skipped = 0
num_updated = 0
num_deleted = 0
for doc_batch in _batch(batch_size, doc_iterator):
hashed_docs = list(
_deduplicate_in_order(
[_HashedDocument.from_document(doc) for doc in doc_batch]
)
)
source_ids: Sequence[Optional[str]] = [
source_id_assigner(doc) for doc in hashed_docs
]
if delete_mode == "incremental":
# If the delete mode is incremental, source ids are required.
for source_id, hashed_doc in zip(source_ids, hashed_docs):
if source_id is None:
raise ValueError(
"Source ids are required when delete mode is incremental. "
f"Document that starts with "
f"content: {hashed_doc.page_content[:100]} was not assigned "
f"as source id."
)
# source ids cannot be None after for loop above.
source_ids = cast(Sequence[str], source_ids) # type: ignore[assignment]
exists_batch = record_manager.exists([doc.uid for doc in hashed_docs])
# Filter out documents that already exist in the record store.
uids = []
docs_to_index = []
for doc, hashed_doc, doc_exists in zip(doc_batch, hashed_docs, exists_batch):
if doc_exists:
# Must be updated to refresh timestamp.
record_manager.update([hashed_doc.uid], time_at_least=index_start_dt)
num_skipped += 1
continue
uids.append(hashed_doc.uid)
docs_to_index.append(doc)
# Be pessimistic and assume that all vector store writes will fail.
# First write to vector store
if docs_to_index:
vector_store.add_documents(docs_to_index, ids=uids)
num_added += len(docs_to_index)
# And only then update the record store.
# Update ALL records, even if they already exist since we want to refresh
# their timestamp.
record_manager.update(
[doc.uid for doc in hashed_docs],
group_ids=source_ids,
time_at_least=index_start_dt,
)
# If source IDs are provided, we can do the deletion incrementally!
if delete_mode == "incremental":
# Get the uids of the documents that were not returned by the loader.
# mypy isn't good enough to determine that source ids cannot be None
# here due to a check that's happening above, so we check again.
for source_id in source_ids:
if source_id is None:
raise AssertionError("Source ids cannot be None here.")
_source_ids = cast(Sequence[str], source_ids)
uids_to_delete = record_manager.list_keys(
group_ids=_source_ids, before=index_start_dt
)
if uids_to_delete:
# First delete from the vector store.
vector_store.delete(uids_to_delete)
# Then delete from the record store.
record_manager.delete_keys(uids_to_delete)
num_deleted += len(uids_to_delete)
if delete_mode == "full":
uids_to_delete = record_manager.list_keys(before=index_start_dt)
if uids_to_delete:
# First delete from the vector store.
vector_store.delete(uids_to_delete)
# Then delete from the record store.
record_manager.delete_keys(uids_to_delete)
num_deleted = len(uids_to_delete)
return {
"num_added": num_added,
"num_updated": num_updated,
"num_skipped": num_skipped,
"num_deleted": num_deleted,
}
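
As a quick illustration of the hashing scheme above (not part of the diff), `_HashedDocument` derives deterministic UUIDs from the page content and metadata, so identical inputs always map to the same uid; this is what makes deduplication and skip detection work. A minimal sketch, with the expected hash taken from the unit test further down in this PR:

```python
from langchain.indexes._api import _HashedDocument
from langchain.schema import Document

doc = Document(page_content="Lorem ipsum dolor sit amet", metadata={"key": "value"})
hashed = _HashedDocument.from_document(doc)

# Content and metadata are hashed separately, then combined; the uid defaults
# to the combined hash when no uid is supplied.
assert hashed.hash_ == "fd1dc827-051b-537d-a1fe-1fa043e8b276"
assert hashed.uid == hashed.hash_
```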

View File

@@ -0,0 +1,265 @@
"""Implementation of a record management layer in SQLAlchemy.
The management layer uses SQLAlchemy to track upserted records.
Currently, this layer only works with SQLite; however, it should be adaptable
to other SQL implementations with minimal effort.
Currently, it includes an implementation that uses SQLAlchemy, which should
allow it to work with a variety of SQL databases as a backend.
* Each key is associated with an updated_at field.
* This field is updated whenever the key is updated.
* Keys can be listed based on the updated at field.
* Keys can be deleted.
"""
import contextlib
import uuid
from typing import Any, Dict, Generator, List, Optional, Sequence
from sqlalchemy import (
Column,
Engine,
Float,
Index,
String,
UniqueConstraint,
and_,
create_engine,
text,
)
from sqlalchemy.dialects.sqlite import insert
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import Session, sessionmaker
from langchain.indexes.base import RecordManager
Base = declarative_base()
class UpsertionRecord(Base): # type: ignore[valid-type,misc]
"""Table used to keep track of when a key was last updated."""
# ATTENTION:
# Prior to modifying this table, please determine whether
# we should create migrations for this table to make sure
# users do not experience data loss.
__tablename__ = "upsertion_record"
uuid = Column(
String,
index=True,
default=lambda: str(uuid.uuid4()),
primary_key=True,
nullable=False,
)
key = Column(String, index=True)
# Using a non-normalized representation to handle `namespace` attribute.
# If the need arises, this attribute can be pulled into a separate Collection
# table at some time later.
namespace = Column(String, index=True, nullable=False)
group_id = Column(String, index=True, nullable=True)
# The timestamp associated with the last record upsertion.
updated_at = Column(Float, index=True)
__table_args__ = (
UniqueConstraint("key", "namespace", name="uix_key_namespace"),
Index("ix_key_namespace", "key", "namespace"),
)
class SQLRecordManager(RecordManager):
"""A SQL Alchemy based implementation of the record manager."""
def __init__(
self,
namespace: str,
*,
engine: Optional[Engine] = None,
db_url: Optional[str] = None,
engine_kwargs: Optional[Dict[str, Any]] = None,
) -> None:
"""Initialize the SQLRecordManager.
This class serves as a persistence layer that uses an SQL
backend to track upserted records. You should specify either a db_url
to create an engine or provide an existing engine.
Args:
namespace: The namespace associated with this record manager.
engine: An already existing SQL Alchemy engine.
Default is None.
db_url: A database connection string used to create
an SQL Alchemy engine. Default is None.
engine_kwargs: Additional keyword arguments
to be passed when creating the engine. Default is an empty dictionary.
Raises:
ValueError: If both db_url and engine are provided or neither.
AssertionError: If something unexpected happens during engine configuration.
"""
super().__init__(namespace=namespace)
if db_url is None and engine is None:
raise ValueError("Must specify either db_url or engine")
if db_url is not None and engine is not None:
raise ValueError("Must specify either db_url or engine, not both")
if db_url:
_kwargs = engine_kwargs or {}
_engine = create_engine(db_url, **_kwargs)
elif engine:
_engine = engine
else:
raise AssertionError("Something went wrong with configuration of engine.")
self.engine = _engine
self.session_factory = sessionmaker(bind=self.engine)
def create_schema(self) -> None:
"""Create the database schema."""
Base.metadata.create_all(self.engine)
@contextlib.contextmanager
def _make_session(self) -> Generator[Session, None, None]:
"""Create a session and close it after use."""
session = self.session_factory()
try:
yield session
finally:
session.close()
def get_time(self) -> float:
"""Get the current server time as a timestamp.
Please note it's critical that time is obtained from the server since
we want a monotonic clock.
"""
with self._make_session() as session:
# * SQLite specific implementation, can be changed based on dialect.
# * Unlike unixepoch, this works with older versions of SQLite.
# ----
# julianday('now'): Julian day number for the current date and time.
# The Julian day is a continuous count of days, starting from a
# reference date (Julian day number 0).
# 2440587.5 - constant represents the Julian day number for January 1, 1970
# 86400.0 - constant represents the number of seconds
# in a day (24 hours * 60 minutes * 60 seconds)
query = text("SELECT (julianday('now') - 2440587.5) * 86400.0;")
dt = session.execute(query).scalar()
if not isinstance(dt, float):
raise AssertionError(f"Unexpected type for datetime: {type(dt)}")
return dt
def update(
self,
keys: Sequence[str],
*,
group_ids: Optional[Sequence[Optional[str]]] = None,
time_at_least: Optional[float] = None,
) -> None:
"""Upsert records into the SQLite database."""
if group_ids is None:
group_ids = [None] * len(keys)
if len(keys) != len(group_ids):
raise ValueError(
f"Number of keys ({len(keys)}) does not match number of "
f"group_ids ({len(group_ids)})"
)
# Get the current time from the server.
# This makes an extra round trip to the server, which should not be a big deal
# if the batch size is large enough.
# Getting the time here helps us compare it against the time_at_least
# and raise an error if there is a time sync issue.
# Here, we're just being extra careful to minimize the chance of
# data loss due to incorrectly deleting records.
update_time = self.get_time()
if time_at_least and update_time < time_at_least:
# Safeguard against time sync issues
raise AssertionError(f"Time sync issue: {update_time} < {time_at_least}")
records_to_upsert = [
{
"key": key,
"namespace": self.namespace,
"updated_at": update_time,
"group_id": group_id,
}
for key, group_id in zip(keys, group_ids)
]
with self._make_session() as session:
# Note: uses SQLite insert to make on_conflict_do_update work.
# This code needs to be generalized a bit to work with more dialects.
insert_stmt = insert(UpsertionRecord).values(records_to_upsert)
stmt = insert_stmt.on_conflict_do_update( # type: ignore[attr-defined]
[UpsertionRecord.key, UpsertionRecord.namespace],
set_=dict(
# attr-defined type ignore
updated_at=insert_stmt.excluded.updated_at, # type: ignore
group_id=insert_stmt.excluded.group_id, # type: ignore
),
)
session.execute(stmt)
session.commit()
def exists(self, keys: Sequence[str]) -> List[bool]:
"""Check if the given keys exist in the SQLite database."""
with self._make_session() as session:
records = (
# mypy does not recognize .all()
session.query(UpsertionRecord.key) # type: ignore[attr-defined]
.filter(
and_(
UpsertionRecord.key.in_(keys),
UpsertionRecord.namespace == self.namespace,
)
)
.all()
)
found_keys = set(r.key for r in records)
return [k in found_keys for k in keys]
def list_keys(
self,
*,
before: Optional[float] = None,
after: Optional[float] = None,
group_ids: Optional[Sequence[str]] = None,
) -> List[str]:
"""List records in the SQLite database based on the provided date range."""
with self._make_session() as session:
query = session.query(UpsertionRecord).filter(
UpsertionRecord.namespace == self.namespace
)
# mypy does not recognize .all() or .filter()
if after:
query = query.filter( # type: ignore[attr-defined]
UpsertionRecord.updated_at > after
)
if before:
query = query.filter( # type: ignore[attr-defined]
UpsertionRecord.updated_at < before
)
if group_ids:
query = query.filter( # type: ignore[attr-defined]
UpsertionRecord.group_id.in_(group_ids)
)
records = query.all() # type: ignore[attr-defined]
return [r.key for r in records]
def delete_keys(self, keys: Sequence[str]) -> None:
"""Delete records from the SQLite database."""
with self._make_session() as session:
# mypy does not recognize .delete()
session.query(UpsertionRecord).filter(
and_(
UpsertionRecord.key.in_(keys),
UpsertionRecord.namespace == self.namespace,
)
).delete() # type: ignore[attr-defined]
session.commit()
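
For reference (not part of the diff), a short sketch of how the record manager can be used on its own, mirroring the unit tests at the bottom of this PR; the namespace and keys are arbitrary:

```python
import time

from langchain.indexes import SQLRecordManager

manager = SQLRecordManager("kittens", db_url="sqlite:///:memory:")
manager.create_schema()

# Upsert a few keys; each upsert stamps the key with the current server time.
manager.update(["key1", "key2"], group_ids=["group1", "group1"])

assert manager.exists(["key1", "key3"]) == [True, False]
assert sorted(manager.list_keys(group_ids=["group1"])) == ["key1", "key2"]

# Keys updated before a cutoff can be listed and deleted; this is exactly how
# the indexing code cleans up stale content.
time.sleep(0.01)  # make sure the cutoff falls strictly after the upserts
cutoff = manager.get_time()
stale = manager.list_keys(before=cutoff)
manager.delete_keys(stale)
assert manager.list_keys() == []
```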

View File

@@ -0,0 +1,95 @@
from __future__ import annotations
import uuid
from abc import ABC, abstractmethod
from typing import List, Optional, Sequence
NAMESPACE_UUID = uuid.UUID(int=1984)
class RecordManager(ABC):
"""An abstract base class representing the interface for a record manager."""
def __init__(
self,
namespace: str,
) -> None:
"""Initialize the record manager.
Args:
namespace (str): The namespace for the record manager.
"""
self.namespace = namespace
@abstractmethod
def create_schema(self) -> None:
"""Create the database schema for the record manager."""
@abstractmethod
def get_time(self) -> float:
"""Get the current server time as a high resolution timestamp!
It's important to get this from the server to ensure a monotonic clock,
otherwise there may be data loss when cleaning up old documents!
Returns:
The current server time as a float timestamp.
"""
@abstractmethod
def update(
self,
keys: Sequence[str],
*,
group_ids: Optional[Sequence[Optional[str]]] = None,
time_at_least: Optional[float] = None,
) -> None:
"""Upsert records into the database.
Args:
keys: A list of record keys to upsert.
group_ids: A list of group IDs corresponding to the keys.
time_at_least: if provided, updates should only happen if the
updated_at field is at least this time.
Raises:
ValueError: If the length of keys doesn't match the length of group_ids.
"""
@abstractmethod
def exists(self, keys: Sequence[str]) -> List[bool]:
"""Check if the provided keys exist in the database.
Args:
keys: A list of keys to check.
Returns:
A list of boolean values indicating the existence of each key.
"""
@abstractmethod
def list_keys(
self,
*,
before: Optional[float] = None,
after: Optional[float] = None,
group_ids: Optional[Sequence[str]] = None,
) -> List[str]:
"""List records in the database based on the provided filters.
Args:
before: Filter to list records updated before this time.
after: Filter to list records updated after this time.
group_ids: Filter to list records with specific group IDs.
Returns:
A list of keys for the matching records.
"""
@abstractmethod
def delete_keys(self, keys: Sequence[str]) -> None:
"""Delete specified records from the database.
Args:
keys: A list of keys to delete.
"""

View File

@@ -0,0 +1,13 @@
from langchain.indexes import __all__
def test_all() -> None:
"""Use to catch obvious breaking changes."""
assert __all__ == sorted(__all__, key=str.lower)
assert __all__ == [
"GraphIndexCreator",
"index",
"IndexingResult",
"SQLRecordManager",
"VectorstoreIndexCreator",
]

View File

@@ -0,0 +1,50 @@
import pytest
from langchain.indexes._api import _HashedDocument
from langchain.schema import Document
def test_hashed_document_hashing() -> None:
hashed_document = _HashedDocument(
uid="123", page_content="Lorem ipsum dolor sit amet", metadata={"key": "value"}
)
assert isinstance(hashed_document.hash_, str)
def test_hashing_with_missing_content() -> None:
"""Check that ValueError is raised if page_content is missing."""
with pytest.raises(ValueError):
_HashedDocument(
metadata={"key": "value"},
)
def test_uid_auto_assigned_to_hash() -> None:
"""Test uid is auto-assigned to the hashed_document hash."""
hashed_document = _HashedDocument(
page_content="Lorem ipsum dolor sit amet", metadata={"key": "value"}
)
assert hashed_document.uid == hashed_document.hash_
def test_to_document() -> None:
"""Test to_document method."""
hashed_document = _HashedDocument(
page_content="Lorem ipsum dolor sit amet", metadata={"key": "value"}
)
doc = hashed_document.to_document()
assert isinstance(doc, Document)
assert doc.page_content == "Lorem ipsum dolor sit amet"
assert doc.metadata == {"key": "value"}
def test_from_document() -> None:
"""Test from document class method."""
document = Document(
page_content="Lorem ipsum dolor sit amet", metadata={"key": "value"}
)
hashed_document = _HashedDocument.from_document(document)
# hash should be deterministic
assert hashed_document.hash_ == "fd1dc827-051b-537d-a1fe-1fa043e8b276"
assert hashed_document.uid == hashed_document.hash_

View File

@@ -0,0 +1,474 @@
from datetime import datetime
from typing import Any, Dict, Iterable, Iterator, List, Optional, Sequence, Type
from unittest.mock import patch
import pytest
from langchain.document_loaders.base import BaseLoader
from langchain.embeddings.base import Embeddings
from langchain.indexes import index
from langchain.indexes._sql_record_manager import SQLRecordManager
from langchain.schema import Document
from langchain.vectorstores.base import VST, VectorStore
class ToyLoader(BaseLoader):
"""Toy loader that always returns the same documents."""
def __init__(self, documents: Sequence[Document]) -> None:
"""Initialize with the documents to return."""
self.documents = documents
def lazy_load(
self,
) -> Iterator[Document]:
yield from self.documents
def load(self) -> List[Document]:
"""Load the documents from the source."""
return list(self.lazy_load())
class InMemoryVectorStore(VectorStore):
"""In-memory implementation of VectorStore using a dictionary."""
def __init__(self) -> None:
"""Vector store interface for testing things in memory."""
self.store: Dict[str, Document] = {}
def delete(self, ids: Optional[Sequence[str]] = None, **kwargs: Any) -> None:
"""Delete the given documents from the store using their IDs."""
if ids:
for _id in ids:
self.store.pop(_id, None)
def add_documents( # type: ignore
self,
documents: Sequence[Document],
*,
ids: Optional[Sequence[str]] = None,
**kwargs: Any,
) -> None:
"""Add the given documents to the store (insert behavior)."""
if ids and len(ids) != len(documents):
raise ValueError(
f"Expected {len(ids)} ids, got {len(documents)} documents."
)
if not ids:
raise NotImplementedError("This is not implemented yet.")
for _id, document in zip(ids, documents):
if _id in self.store:
raise ValueError(
f"Document with uid {_id} already exists in the store."
)
self.store[_id] = document
def add_texts(
self,
texts: Iterable[str],
metadatas: Optional[List[dict]] = None,
**kwargs: Any,
) -> List[str]:
"""Add the given texts to the store (insert behavior)."""
raise NotImplementedError()
@classmethod
def from_texts(
cls: Type[VST],
texts: List[str],
embedding: Embeddings,
metadatas: Optional[List[dict]] = None,
**kwargs: Any,
) -> VST:
"""Create a vector store from a list of texts."""
raise NotImplementedError()
def similarity_search(
self, query: str, k: int = 4, **kwargs: Any
) -> List[Document]:
"""Find the most similar documents to the given query."""
raise NotImplementedError()
@pytest.fixture
def record_manager() -> SQLRecordManager:
"""Timestamped set fixture."""
record_manager = SQLRecordManager("kittens", db_url="sqlite:///:memory:")
record_manager.create_schema()
return record_manager
@pytest.fixture
def vector_store() -> InMemoryVectorStore:
"""Vector store fixture."""
return InMemoryVectorStore()
def test_indexing_same_content(
record_manager: SQLRecordManager, vector_store: InMemoryVectorStore
) -> None:
"""Indexing some content to confirm it gets added only once."""
loader = ToyLoader(
documents=[
Document(
page_content="This is a test document.",
),
Document(
page_content="This is another document.",
),
]
)
assert index(loader, record_manager, vector_store) == {
"num_added": 2,
"num_deleted": 0,
"num_skipped": 0,
"num_updated": 0,
}
assert len(list(vector_store.store)) == 2
for _ in range(2):
# Run the indexing again
assert index(loader, record_manager, vector_store) == {
"num_added": 0,
"num_deleted": 0,
"num_skipped": 2,
"num_updated": 0,
}
def test_index_simple_delete_full(
record_manager: SQLRecordManager, vector_store: InMemoryVectorStore
) -> None:
"""Indexing some content to confirm it gets added only once."""
loader = ToyLoader(
documents=[
Document(
page_content="This is a test document.",
),
Document(
page_content="This is another document.",
),
]
)
with patch.object(
record_manager, "get_time", return_value=datetime(2021, 1, 1).timestamp()
):
assert index(loader, record_manager, vector_store, delete_mode="full") == {
"num_added": 2,
"num_deleted": 0,
"num_skipped": 0,
"num_updated": 0,
}
with patch.object(
record_manager, "get_time", return_value=datetime(2021, 1, 1).timestamp()
):
assert index(loader, record_manager, vector_store, delete_mode="full") == {
"num_added": 0,
"num_deleted": 0,
"num_skipped": 2,
"num_updated": 0,
}
loader = ToyLoader(
documents=[
Document(
page_content="mutated document 1",
),
Document(
page_content="This is another document.", # <-- Same as original
),
]
)
with patch.object(
record_manager, "get_time", return_value=datetime(2021, 1, 2).timestamp()
):
assert index(loader, record_manager, vector_store, delete_mode="full") == {
"num_added": 1,
"num_deleted": 1,
"num_skipped": 1,
"num_updated": 0,
}
doc_texts = set(
# Ignoring type since doc should be in the store and not a None
vector_store.store.get(uid).page_content # type: ignore
for uid in vector_store.store
)
assert doc_texts == {"mutated document 1", "This is another document."}
# Attempt to index again and verify that nothing changes
with patch.object(
record_manager, "get_time", return_value=datetime(2021, 1, 2).timestamp()
):
assert index(loader, record_manager, vector_store, delete_mode="full") == {
"num_added": 0,
"num_deleted": 0,
"num_skipped": 2,
"num_updated": 0,
}
def test_incremental_fails_with_bad_source_ids(
record_manager: SQLRecordManager, vector_store: InMemoryVectorStore
) -> None:
"""Test indexing with incremental deletion strategy."""
loader = ToyLoader(
documents=[
Document(
page_content="This is a test document.",
metadata={"source": "1"},
),
Document(
page_content="This is another document.",
metadata={"source": "2"},
),
Document(
page_content="This is yet another document.",
metadata={"source": None},
),
]
)
with pytest.raises(ValueError):
# Should raise an error because no source id function was specified
index(loader, record_manager, vector_store, delete_mode="incremental")
with pytest.raises(ValueError):
# Should raise an error because a document is missing a source id
index(
loader,
record_manager,
vector_store,
delete_mode="incremental",
source_id_key="source",
)
def test_no_delete(
record_manager: SQLRecordManager, vector_store: InMemoryVectorStore
) -> None:
"""Test indexing without a deletion strategy."""
loader = ToyLoader(
documents=[
Document(
page_content="This is a test document.",
metadata={"source": "1"},
),
Document(
page_content="This is another document.",
metadata={"source": "2"},
),
]
)
with patch.object(
record_manager, "get_time", return_value=datetime(2021, 1, 2).timestamp()
):
assert index(
loader,
record_manager,
vector_store,
delete_mode=None,
source_id_key="source",
) == {
"num_added": 2,
"num_deleted": 0,
"num_skipped": 0,
"num_updated": 0,
}
# If we add the same content twice it should be skipped
with patch.object(
record_manager, "get_time", return_value=datetime(2021, 1, 2).timestamp()
):
assert index(
loader,
record_manager,
vector_store,
delete_mode=None,
source_id_key="source",
) == {
"num_added": 0,
"num_deleted": 0,
"num_skipped": 2,
"num_updated": 0,
}
loader = ToyLoader(
documents=[
Document(
page_content="mutated content",
metadata={"source": "1"},
),
Document(
page_content="This is another document.",
metadata={"source": "2"},
),
]
)
# Should result in no updates or deletions!
with patch.object(
record_manager, "get_time", return_value=datetime(2021, 1, 2).timestamp()
):
assert index(
loader,
record_manager,
vector_store,
delete_mode=None,
source_id_key="source",
) == {
"num_added": 1,
"num_deleted": 0,
"num_skipped": 1,
"num_updated": 0,
}
def test_incremental_delete(
record_manager: SQLRecordManager, vector_store: InMemoryVectorStore
) -> None:
"""Test indexing with incremental deletion strategy."""
loader = ToyLoader(
documents=[
Document(
page_content="This is a test document.",
metadata={"source": "1"},
),
Document(
page_content="This is another document.",
metadata={"source": "2"},
),
]
)
with patch.object(
record_manager, "get_time", return_value=datetime(2021, 1, 2).timestamp()
):
assert index(
loader,
record_manager,
vector_store,
delete_mode="incremental",
source_id_key="source",
) == {
"num_added": 2,
"num_deleted": 0,
"num_skipped": 0,
"num_updated": 0,
}
doc_texts = set(
# Ignoring type since doc should be in the store and not a None
vector_store.store.get(uid).page_content # type: ignore
for uid in vector_store.store
)
assert doc_texts == {"This is another document.", "This is a test document."}
# Attempt to index again and verify that nothing changes
with patch.object(
record_manager, "get_time", return_value=datetime(2021, 1, 2).timestamp()
):
assert index(
loader,
record_manager,
vector_store,
delete_mode="incremental",
source_id_key="source",
) == {
"num_added": 0,
"num_deleted": 0,
"num_skipped": 2,
"num_updated": 0,
}
# Create 2 documents from the same source all with mutated content
loader = ToyLoader(
documents=[
Document(
page_content="mutated document 1",
metadata={"source": "1"},
),
Document(
page_content="mutated document 2",
metadata={"source": "1"},
),
Document(
page_content="This is another document.", # <-- Same as original
metadata={"source": "2"},
),
]
)
# Index the mutated documents; stale content from source 1 should be cleaned up
with patch.object(
record_manager, "get_time", return_value=datetime(2021, 1, 3).timestamp()
):
assert index(
loader,
record_manager,
vector_store,
delete_mode="incremental",
source_id_key="source",
) == {
"num_added": 2,
"num_deleted": 1,
"num_skipped": 1,
"num_updated": 0,
}
doc_texts = set(
# Ignoring type since doc should be in the store and not a None
vector_store.store.get(uid).page_content # type: ignore
for uid in vector_store.store
)
assert doc_texts == {
"mutated document 1",
"mutated document 2",
"This is another document.",
}
def test_indexing_with_no_docs(
record_manager: SQLRecordManager, vector_store: VectorStore
) -> None:
"""Check edge case when loader returns no new docs."""
loader = ToyLoader(documents=[])
assert index(loader, record_manager, vector_store, delete_mode="full") == {
"num_added": 0,
"num_deleted": 0,
"num_skipped": 0,
"num_updated": 0,
}
def test_deduplication(
record_manager: SQLRecordManager, vector_store: VectorStore
) -> None:
"""Check edge case when loader returns no new docs."""
docs = [
Document(
page_content="This is a test document.",
metadata={"source": "1"},
),
Document(
page_content="This is a test document.",
metadata={"source": "1"},
),
]
# Should result in only a single document being added
assert index(docs, record_manager, vector_store, delete_mode="full") == {
"num_added": 1,
"num_deleted": 0,
"num_skipped": 0,
"num_updated": 0,
}

View File

@@ -0,0 +1,276 @@
from datetime import datetime
from unittest.mock import patch
import pytest
from langchain.indexes._sql_record_manager import SQLRecordManager, UpsertionRecord
@pytest.fixture()
def manager() -> SQLRecordManager:
"""Initialize the test database and yield the TimestampedSet instance."""
# Initialize and yield the TimestampedSet instance
record_manager = SQLRecordManager("kittens", db_url="sqlite:///:memory:")
record_manager.create_schema()
return record_manager
def test_update(manager: SQLRecordManager) -> None:
"""Test updating records in the database."""
# no keys should be present in the set
read_keys = manager.list_keys()
assert read_keys == []
# Insert records
keys = ["key1", "key2", "key3"]
manager.update(keys)
# Retrieve the records
read_keys = manager.list_keys()
assert read_keys == ["key1", "key2", "key3"]
def test_update_timestamp(manager: SQLRecordManager) -> None:
"""Test updating records in the database."""
# no keys should be present in the set
with patch.object(
manager, "get_time", return_value=datetime(2021, 1, 2).timestamp()
):
manager.update(["key1"])
with manager._make_session() as session:
records = (
session.query(UpsertionRecord)
.filter(UpsertionRecord.namespace == manager.namespace)
.all() # type: ignore[attr-defined]
)
assert [
{
"key": record.key,
"namespace": record.namespace,
"updated_at": record.updated_at,
"group_id": record.group_id,
}
for record in records
] == [
{
"group_id": None,
"key": "key1",
"namespace": "kittens",
"updated_at": datetime(2021, 1, 2, 0, 0).timestamp(),
}
]
with patch.object(
manager, "get_time", return_value=datetime(2023, 1, 2).timestamp()
):
manager.update(["key1"])
with manager._make_session() as session:
records = (
session.query(UpsertionRecord)
.filter(UpsertionRecord.namespace == manager.namespace)
.all() # type: ignore[attr-defined]
)
assert [
{
"key": record.key,
"namespace": record.namespace,
"updated_at": record.updated_at,
"group_id": record.group_id,
}
for record in records
] == [
{
"group_id": None,
"key": "key1",
"namespace": "kittens",
"updated_at": datetime(2023, 1, 2, 0, 0).timestamp(),
}
]
with patch.object(
manager, "get_time", return_value=datetime(2023, 2, 2).timestamp()
):
manager.update(["key1"], group_ids=["group1"])
with manager._make_session() as session:
records = (
session.query(UpsertionRecord)
.filter(UpsertionRecord.namespace == manager.namespace)
.all() # type: ignore[attr-defined]
)
assert [
{
"key": record.key,
"namespace": record.namespace,
"updated_at": record.updated_at,
"group_id": record.group_id,
}
for record in records
] == [
{
"group_id": "group1",
"key": "key1",
"namespace": "kittens",
"updated_at": datetime(2023, 2, 2, 0, 0).timestamp(),
}
]
def test_update_with_group_ids(manager: SQLRecordManager) -> None:
"""Test updating records in the database."""
# no keys should be present in the set
read_keys = manager.list_keys()
assert read_keys == []
# Insert records
keys = ["key1", "key2", "key3"]
manager.update(keys)
# Retrieve the records
read_keys = manager.list_keys()
assert read_keys == ["key1", "key2", "key3"]
def test_exists(manager: SQLRecordManager) -> None:
"""Test checking if keys exist in the database."""
# Insert records
keys = ["key1", "key2", "key3"]
manager.update(keys)
# Check if the keys exist in the database
exists = manager.exists(keys)
assert len(exists) == len(keys)
assert exists == [True, True, True]
exists = manager.exists(["key1", "key4"])
assert len(exists) == 2
assert exists == [True, False]
def test_list_keys(manager: SQLRecordManager) -> None:
"""Test listing keys based on the provided date range."""
# Insert records
assert manager.list_keys() == []
with manager._make_session() as session:
# Add some keys with explicit updated_ats
session.add(
UpsertionRecord(
key="key1",
updated_at=datetime(2021, 1, 1).timestamp(),
namespace="kittens",
)
)
session.add(
UpsertionRecord(
key="key2",
updated_at=datetime(2022, 1, 1).timestamp(),
namespace="kittens",
)
)
session.add(
UpsertionRecord(
key="key3",
updated_at=datetime(2023, 1, 1).timestamp(),
namespace="kittens",
)
)
session.add(
UpsertionRecord(
key="key4",
group_id="group1",
updated_at=datetime(2024, 1, 1).timestamp(),
namespace="kittens",
)
)
# Insert keys from a different namespace, these should not be visible!
session.add(
UpsertionRecord(
key="key1",
updated_at=datetime(2021, 1, 1).timestamp(),
namespace="puppies",
)
)
session.add(
UpsertionRecord(
key="key5",
updated_at=datetime(2021, 1, 1).timestamp(),
namespace="puppies",
)
)
session.commit()
# Retrieve all keys
assert manager.list_keys() == ["key1", "key2", "key3", "key4"]
# Retrieve keys updated after a certain date
assert manager.list_keys(after=datetime(2022, 2, 1).timestamp()) == ["key3", "key4"]
# Retrieve keys updated before a certain date
assert manager.list_keys(before=datetime(2022, 2, 1).timestamp()) == [
"key1",
"key2",
]
# No keys should have been updated before this early date
assert manager.list_keys(before=datetime(2019, 2, 1).timestamp()) == []
# Retrieve keys in a time range
assert manager.list_keys(
before=datetime(2022, 2, 1).timestamp(),
after=datetime(2021, 11, 1).timestamp(),
) == ["key2"]
assert manager.list_keys(group_ids=["group1", "group2"]) == ["key4"]
# Test multiple filters
assert (
manager.list_keys(
group_ids=["group1", "group2"], before=datetime(2019, 1, 1).timestamp()
)
== []
)
assert manager.list_keys(
group_ids=["group1", "group2"], after=datetime(2019, 1, 1).timestamp()
) == ["key4"]
def test_namespace_is_used(manager: SQLRecordManager) -> None:
"""Verify that namespace is taken into account for all operations."""
assert manager.namespace == "kittens"
with manager._make_session() as session:
# Add some keys directly, across two namespaces
session.add(UpsertionRecord(key="key1", namespace="kittens"))
session.add(UpsertionRecord(key="key2", namespace="kittens"))
session.add(UpsertionRecord(key="key1", namespace="puppies"))
session.add(UpsertionRecord(key="key3", namespace="puppies"))
session.commit()
assert manager.list_keys() == ["key1", "key2"]
manager.delete_keys(["key1"])
assert manager.list_keys() == ["key2"]
manager.update(["key3"], group_ids=["group3"])
with manager._make_session() as session:
results = session.query(UpsertionRecord).all()
assert sorted([(r.namespace, r.key, r.group_id) for r in results]) == [
("kittens", "key2", None),
("kittens", "key3", "group3"),
("puppies", "key1", None),
("puppies", "key3", None),
]
def test_delete_keys(manager: SQLRecordManager) -> None:
"""Test deleting keys from the database."""
# Insert records
keys = ["key1", "key2", "key3"]
manager.update(keys)
# Delete some keys
keys_to_delete = ["key1", "key2"]
manager.delete_keys(keys_to_delete)
# Check if the deleted keys are no longer in the database
remaining_keys = manager.list_keys()
assert remaining_keys == ["key3"]