diff --git a/libs/langchain/langchain/indexes/__init__.py b/libs/langchain/langchain/indexes/__init__.py index 792cf30c8a8..84b1ad156be 100644 --- a/libs/langchain/langchain/indexes/__init__.py +++ b/libs/langchain/langchain/indexes/__init__.py @@ -13,13 +13,14 @@ Importantly, this keeps on working even if the content being written is derived via a set of transformations from some source content (e.g., indexing children documents that were derived from parent documents by chunking.) """ -from langchain.indexes._api import IndexingResult, index +from langchain.indexes._api import IndexingResult, aindex, index from langchain.indexes._sql_record_manager import SQLRecordManager from langchain.indexes.graph import GraphIndexCreator from langchain.indexes.vectorstore import VectorstoreIndexCreator __all__ = [ # Keep sorted + "aindex", "GraphIndexCreator", "index", "IndexingResult", diff --git a/libs/langchain/langchain/indexes/_api.py b/libs/langchain/langchain/indexes/_api.py index 88ab4e6a9ab..a386e656341 100644 --- a/libs/langchain/langchain/indexes/_api.py +++ b/libs/langchain/langchain/indexes/_api.py @@ -7,6 +7,8 @@ import uuid from itertools import islice from typing import ( Any, + AsyncIterable, + AsyncIterator, Callable, Dict, Iterable, @@ -15,6 +17,7 @@ from typing import ( Literal, Optional, Sequence, + Set, TypedDict, TypeVar, Union, @@ -36,7 +39,7 @@ def _hash_string_to_uuid(input_string: str) -> uuid.UUID: return uuid.uuid5(NAMESPACE_UUID, hash_value) -def _hash_nested_dict_to_uuid(data: dict) -> uuid.UUID: +def _hash_nested_dict_to_uuid(data: dict[Any, Any]) -> uuid.UUID: """Hashes a nested dictionary and returns the corresponding UUID.""" serialized_data = json.dumps(data, sort_keys=True) hash_value = hashlib.sha1(serialized_data.encode("utf-8")).hexdigest() @@ -118,6 +121,21 @@ def _batch(size: int, iterable: Iterable[T]) -> Iterator[List[T]]: yield chunk +async def _abatch(size: int, iterable: AsyncIterable[T]) -> AsyncIterator[List[T]]: + """Utility batching function.""" + batch: List[T] = [] + async for element in iterable: + if len(batch) < size: + batch.append(element) + + if len(batch) >= size: + yield batch + batch = [] + + if batch: + yield batch + + def _get_source_id_assigner( source_id_key: Union[str, Callable[[Document], str], None], ) -> Callable[[Document], Union[str, None]]: @@ -139,7 +157,7 @@ def _deduplicate_in_order( hashed_documents: Iterable[_HashedDocument], ) -> Iterator[_HashedDocument]: """Deduplicate a list of hashed documents while preserving order.""" - seen = set() + seen: Set[str] = set() for hashed_doc in hashed_documents: if hashed_doc.hash_ not in seen: @@ -346,3 +364,203 @@ def index( "num_skipped": num_skipped, "num_deleted": num_deleted, } + + +# Define an asynchronous generator function +async def _to_async_iterator(iterator: Iterable[T]) -> AsyncIterator[T]: + """Convert an iterable to an async iterator.""" + for item in iterator: + yield item + + +async def aindex( + docs_source: Union[Iterable[Document], AsyncIterator[Document]], + record_manager: RecordManager, + vector_store: VectorStore, + *, + batch_size: int = 100, + cleanup: Literal["incremental", "full", None] = None, + source_id_key: Union[str, Callable[[Document], str], None] = None, + cleanup_batch_size: int = 1_000, +) -> IndexingResult: + """Index data from the loader into the vector store. + + Indexing functionality uses a manager to keep track of which documents + are in the vector store. 
+ + This allows us to keep track of which documents were updated, and which + documents were deleted, which documents should be skipped. + + For the time being, documents are indexed using their hashes, and users + are not able to specify the uid of the document. + + IMPORTANT: + if auto_cleanup is set to True, the loader should be returning + the entire dataset, and not just a subset of the dataset. + Otherwise, the auto_cleanup will remove documents that it is not + supposed to. + + Args: + docs_source: Data loader or iterable of documents to index. + record_manager: Timestamped set to keep track of which documents were + updated. + vector_store: Vector store to index the documents into. + batch_size: Batch size to use when indexing. + cleanup: How to handle clean up of documents. + - Incremental: Cleans up all documents that haven't been updated AND + that are associated with source ids that were seen + during indexing. + Clean up is done continuously during indexing helping + to minimize the probability of users seeing duplicated + content. + - Full: Delete all documents that haven to been returned by the loader. + Clean up runs after all documents have been indexed. + This means that users may see duplicated content during indexing. + - None: Do not delete any documents. + source_id_key: Optional key that helps identify the original source + of the document. + cleanup_batch_size: Batch size to use when cleaning up documents. + + Returns: + Indexing result which contains information about how many documents + were added, updated, deleted, or skipped. + """ + + if cleanup not in {"incremental", "full", None}: + raise ValueError( + f"cleanup should be one of 'incremental', 'full' or None. " + f"Got {cleanup}." + ) + + if cleanup == "incremental" and source_id_key is None: + raise ValueError("Source id key is required when cleanup mode is incremental.") + + # Check that the Vectorstore has required methods implemented + methods = ["adelete", "aadd_documents"] + + for method in methods: + if not hasattr(vector_store, method): + raise ValueError( + f"Vectorstore {vector_store} does not have required method {method}" + ) + + if type(vector_store).adelete == VectorStore.adelete: + # Checking if the vectorstore has overridden the default delete method + # implementation which just raises a NotImplementedError + raise ValueError("Vectorstore has not implemented the delete method") + + if isinstance(docs_source, BaseLoader): + raise NotImplementedError( + "Not supported yet. Please pass an async iterator of documents." + ) + async_doc_iterator: AsyncIterator[Document] + + if hasattr(docs_source, "__aiter__"): + async_doc_iterator = docs_source # type: ignore[assignment] + else: + async_doc_iterator = _to_async_iterator(docs_source) + + source_id_assigner = _get_source_id_assigner(source_id_key) + + # Mark when the update started. + index_start_dt = await record_manager.aget_time() + num_added = 0 + num_skipped = 0 + num_updated = 0 + num_deleted = 0 + + async for doc_batch in _abatch(batch_size, async_doc_iterator): + hashed_docs = list( + _deduplicate_in_order( + [_HashedDocument.from_document(doc) for doc in doc_batch] + ) + ) + + source_ids: Sequence[Optional[str]] = [ + source_id_assigner(doc) for doc in hashed_docs + ] + + if cleanup == "incremental": + # If the cleanup mode is incremental, source ids are required. + for source_id, hashed_doc in zip(source_ids, hashed_docs): + if source_id is None: + raise ValueError( + "Source ids are required when cleanup mode is incremental. 
" + f"Document that starts with " + f"content: {hashed_doc.page_content[:100]} was not assigned " + f"as source id." + ) + # source ids cannot be None after for loop above. + source_ids = cast(Sequence[str], source_ids) + + exists_batch = await record_manager.aexists([doc.uid for doc in hashed_docs]) + + # Filter out documents that already exist in the record store. + uids: list[str] = [] + docs_to_index: list[Document] = [] + + for hashed_doc, doc_exists in zip(hashed_docs, exists_batch): + if doc_exists: + # Must be updated to refresh timestamp. + await record_manager.aupdate( + [hashed_doc.uid], time_at_least=index_start_dt + ) + num_skipped += 1 + continue + uids.append(hashed_doc.uid) + docs_to_index.append(hashed_doc.to_document()) + + # Be pessimistic and assume that all vector store write will fail. + # First write to vector store + if docs_to_index: + await vector_store.aadd_documents(docs_to_index, ids=uids) + num_added += len(docs_to_index) + + # And only then update the record store. + # Update ALL records, even if they already exist since we want to refresh + # their timestamp. + await record_manager.aupdate( + [doc.uid for doc in hashed_docs], + group_ids=source_ids, + time_at_least=index_start_dt, + ) + + # If source IDs are provided, we can do the deletion incrementally! + + if cleanup == "incremental": + # Get the uids of the documents that were not returned by the loader. + + # mypy isn't good enough to determine that source ids cannot be None + # here due to a check that's happening above, so we check again. + for source_id in source_ids: + if source_id is None: + raise AssertionError("Source ids cannot be None here.") + + _source_ids = cast(Sequence[str], source_ids) + + uids_to_delete = await record_manager.alist_keys( + group_ids=_source_ids, before=index_start_dt + ) + if uids_to_delete: + # Then delete from vector store. + await vector_store.adelete(uids_to_delete) + # First delete from record store. + await record_manager.adelete_keys(uids_to_delete) + num_deleted += len(uids_to_delete) + + if cleanup == "full": + while uids_to_delete := await record_manager.alist_keys( + before=index_start_dt, limit=cleanup_batch_size + ): + # First delete from record store. + await vector_store.adelete(uids_to_delete) + # Then delete from record manager. + await record_manager.adelete_keys(uids_to_delete) + num_deleted += len(uids_to_delete) + + return { + "num_added": num_added, + "num_updated": num_updated, + "num_skipped": num_skipped, + "num_deleted": num_deleted, + } diff --git a/libs/langchain/langchain/indexes/_sql_record_manager.py b/libs/langchain/langchain/indexes/_sql_record_manager.py index d9e579aa5a5..14e2355af22 100644 --- a/libs/langchain/langchain/indexes/_sql_record_manager.py +++ b/libs/langchain/langchain/indexes/_sql_record_manager.py @@ -16,7 +16,7 @@ allow it to work with a variety of SQL as a backend. 
import contextlib import decimal import uuid -from typing import Any, Dict, Generator, List, Optional, Sequence, Union +from typing import Any, AsyncGenerator, Dict, Generator, List, Optional, Sequence, Union from sqlalchemy import ( URL, @@ -28,8 +28,16 @@ from sqlalchemy import ( UniqueConstraint, and_, create_engine, + delete, + select, text, ) +from sqlalchemy.ext.asyncio import ( + AsyncEngine, + AsyncSession, + async_sessionmaker, + create_async_engine, +) from sqlalchemy.ext.declarative import declarative_base from sqlalchemy.orm import Session, sessionmaker @@ -77,9 +85,10 @@ class SQLRecordManager(RecordManager): self, namespace: str, *, - engine: Optional[Engine] = None, + engine: Optional[Union[Engine, AsyncEngine]] = None, db_url: Union[None, str, URL] = None, engine_kwargs: Optional[Dict[str, Any]] = None, + async_mode: bool = False, ) -> None: """Initialize the SQLRecordManager. @@ -95,6 +104,10 @@ class SQLRecordManager(RecordManager): an SQL Alchemy engine. Default is None. engine_kwargs: Additional keyword arguments to be passed when creating the engine. Default is an empty dictionary. + async_mode: Whether to create an async engine. + Driver should support async operations. + It only applies if db_url is provided. + Default is False. Raises: ValueError: If both db_url and engine are provided or neither. @@ -103,34 +116,71 @@ class SQLRecordManager(RecordManager): super().__init__(namespace=namespace) if db_url is None and engine is None: raise ValueError("Must specify either db_url or engine") + if db_url is not None and engine is not None: raise ValueError("Must specify either db_url or engine, not both") + _engine: Union[Engine, AsyncEngine] if db_url: - _kwargs = engine_kwargs or {} - _engine = create_engine(db_url, **_kwargs) + if async_mode: + _engine = create_async_engine(db_url, **(engine_kwargs or {})) + else: + _engine = create_engine(db_url, **(engine_kwargs or {})) elif engine: _engine = engine + else: raise AssertionError("Something went wrong with configuration of engine.") + _session_factory: Union[sessionmaker[Session], async_sessionmaker[AsyncSession]] + if isinstance(_engine, AsyncEngine): + _session_factory = async_sessionmaker(bind=_engine) + else: + _session_factory = sessionmaker(bind=_engine) + self.engine = _engine self.dialect = _engine.dialect.name - self.session_factory = sessionmaker(bind=self.engine) + self.session_factory = _session_factory def create_schema(self) -> None: """Create the database schema.""" + if isinstance(self.engine, AsyncEngine): + raise AssertionError("This method is not supported for async engines.") + Base.metadata.create_all(self.engine) + async def acreate_schema(self) -> None: + """Create the database schema.""" + + if not isinstance(self.engine, AsyncEngine): + raise AssertionError("This method is not supported for sync engines.") + + async with self.engine.begin() as session: + await session.run_sync(Base.metadata.create_all) + @contextlib.contextmanager def _make_session(self) -> Generator[Session, None, None]: """Create a session and close it after use.""" + + if isinstance(self.session_factory, async_sessionmaker): + raise AssertionError("This method is not supported for async engines.") + session = self.session_factory() try: yield session finally: session.close() + @contextlib.asynccontextmanager + async def _amake_session(self) -> AsyncGenerator[AsyncSession, None]: + """Create a session and close it after use.""" + + if not isinstance(self.session_factory, async_sessionmaker): + raise AssertionError("This 
method is not supported for sync engines.") + + async with self.session_factory() as session: + yield session + def get_time(self) -> float: """Get the current server time as a timestamp. @@ -161,6 +211,37 @@ class SQLRecordManager(RecordManager): raise AssertionError(f"Unexpected type for datetime: {type(dt)}") return dt + async def aget_time(self) -> float: + """Get the current server time as a timestamp. + + Please note it's critical that time is obtained from the server since + we want a monotonic clock. + """ + async with self._amake_session() as session: + # * SQLite specific implementation, can be changed based on dialect. + # * For SQLite, unlike unixepoch it will work with older versions of SQLite. + # ---- + # julianday('now'): Julian day number for the current date and time. + # The Julian day is a continuous count of days, starting from a + # reference date (Julian day number 0). + # 2440587.5 - constant represents the Julian day number for January 1, 1970 + # 86400.0 - constant represents the number of seconds + # in a day (24 hours * 60 minutes * 60 seconds) + if self.dialect == "sqlite": + query = text("SELECT (julianday('now') - 2440587.5) * 86400.0;") + elif self.dialect == "postgresql": + query = text("SELECT EXTRACT (EPOCH FROM CURRENT_TIMESTAMP);") + else: + raise NotImplementedError(f"Not implemented for dialect {self.dialect}") + + dt = (await session.execute(query)).scalar_one_or_none() + + if isinstance(dt, decimal.Decimal): + dt = float(dt) + if not isinstance(dt, float): + raise AssertionError(f"Unexpected type for datetime: {type(dt)}") + return dt + def update( self, keys: Sequence[str], @@ -236,6 +317,81 @@ class SQLRecordManager(RecordManager): session.execute(stmt) session.commit() + async def aupdate( + self, + keys: Sequence[str], + *, + group_ids: Optional[Sequence[Optional[str]]] = None, + time_at_least: Optional[float] = None, + ) -> None: + """Upsert records into the SQLite database.""" + if group_ids is None: + group_ids = [None] * len(keys) + + if len(keys) != len(group_ids): + raise ValueError( + f"Number of keys ({len(keys)}) does not match number of " + f"group_ids ({len(group_ids)})" + ) + + # Get the current time from the server. + # This makes an extra round trip to the server, should not be a big deal + # if the batch size is large enough. + # Getting the time here helps us compare it against the time_at_least + # and raise an error if there is a time sync issue. + # Here, we're just being extra careful to minimize the chance of + # data loss due to incorrectly deleting records. + update_time = await self.aget_time() + + if time_at_least and update_time < time_at_least: + # Safeguard against time sync issues + raise AssertionError(f"Time sync issue: {update_time} < {time_at_least}") + + records_to_upsert = [ + { + "key": key, + "namespace": self.namespace, + "updated_at": update_time, + "group_id": group_id, + } + for key, group_id in zip(keys, group_ids) + ] + + async with self._amake_session() as session: + if self.dialect == "sqlite": + from sqlalchemy.dialects.sqlite import insert as sqlite_insert + + # Note: uses SQLite insert to make on_conflict_do_update work. + # This code needs to be generalized a bit to work with more dialects. 
+ insert_stmt = sqlite_insert(UpsertionRecord).values(records_to_upsert) + stmt = insert_stmt.on_conflict_do_update( # type: ignore[attr-defined] + [UpsertionRecord.key, UpsertionRecord.namespace], + set_=dict( + # attr-defined type ignore + updated_at=insert_stmt.excluded.updated_at, # type: ignore + group_id=insert_stmt.excluded.group_id, # type: ignore + ), + ) + elif self.dialect == "postgresql": + from sqlalchemy.dialects.postgresql import insert as pg_insert + + # Note: uses SQLite insert to make on_conflict_do_update work. + # This code needs to be generalized a bit to work with more dialects. + insert_stmt = pg_insert(UpsertionRecord).values(records_to_upsert) + stmt = insert_stmt.on_conflict_do_update( # type: ignore[attr-defined] + "uix_key_namespace", # Name of constraint + set_=dict( + # attr-defined type ignore + updated_at=insert_stmt.excluded.updated_at, # type: ignore + group_id=insert_stmt.excluded.group_id, # type: ignore + ), + ) + else: + raise NotImplementedError(f"Unsupported dialect {self.dialect}") + + await session.execute(stmt) + await session.commit() + def exists(self, keys: Sequence[str]) -> List[bool]: """Check if the given keys exist in the SQLite database.""" with self._make_session() as session: @@ -253,6 +409,26 @@ class SQLRecordManager(RecordManager): found_keys = set(r.key for r in records) return [k in found_keys for k in keys] + async def aexists(self, keys: Sequence[str]) -> List[bool]: + """Check if the given keys exist in the SQLite database.""" + async with self._amake_session() as session: + records = ( + ( + await session.execute( + select(UpsertionRecord.key).where( + and_( + UpsertionRecord.key.in_(keys), + UpsertionRecord.namespace == self.namespace, + ) + ) + ) + ) + .scalars() + .all() + ) + found_keys = set(records) + return [k in found_keys for k in keys] + def list_keys( self, *, @@ -286,6 +462,39 @@ class SQLRecordManager(RecordManager): records = query.all() # type: ignore[attr-defined] return [r.key for r in records] + async def alist_keys( + self, + *, + before: Optional[float] = None, + after: Optional[float] = None, + group_ids: Optional[Sequence[str]] = None, + limit: Optional[int] = None, + ) -> List[str]: + """List records in the SQLite database based on the provided date range.""" + async with self._amake_session() as session: + query = select(UpsertionRecord.key).filter( + UpsertionRecord.namespace == self.namespace + ) + + # mypy does not recognize .all() or .filter() + if after: + query = query.filter( # type: ignore[attr-defined] + UpsertionRecord.updated_at > after + ) + if before: + query = query.filter( # type: ignore[attr-defined] + UpsertionRecord.updated_at < before + ) + if group_ids: + query = query.filter( # type: ignore[attr-defined] + UpsertionRecord.group_id.in_(group_ids) + ) + + if limit: + query = query.limit(limit) # type: ignore[attr-defined] + records = (await session.execute(query)).scalars().all() + return list(records) + def delete_keys(self, keys: Sequence[str]) -> None: """Delete records from the SQLite database.""" with self._make_session() as session: @@ -297,3 +506,17 @@ class SQLRecordManager(RecordManager): ) ).delete() # type: ignore[attr-defined] session.commit() + + async def adelete_keys(self, keys: Sequence[str]) -> None: + """Delete records from the SQLite database.""" + async with self._amake_session() as session: + await session.execute( + delete(UpsertionRecord).where( + and_( + UpsertionRecord.key.in_(keys), + UpsertionRecord.namespace == self.namespace, + ) + ) + ) + + await 
session.commit() diff --git a/libs/langchain/langchain/indexes/base.py b/libs/langchain/langchain/indexes/base.py index 69b6e6b5bf4..46ef5bf2efa 100644 --- a/libs/langchain/langchain/indexes/base.py +++ b/libs/langchain/langchain/indexes/base.py @@ -25,6 +25,10 @@ class RecordManager(ABC): def create_schema(self) -> None: """Create the database schema for the record manager.""" + @abstractmethod + async def acreate_schema(self) -> None: + """Create the database schema for the record manager.""" + @abstractmethod def get_time(self) -> float: """Get the current server time as a high resolution timestamp! @@ -36,6 +40,17 @@ class RecordManager(ABC): The current server time as a float timestamp. """ + @abstractmethod + async def aget_time(self) -> float: + """Get the current server time as a high resolution timestamp! + + It's important to get this from the server to ensure a monotonic clock, + otherwise there may be data loss when cleaning up old documents! + + Returns: + The current server time as a float timestamp. + """ + @abstractmethod def update( self, @@ -56,6 +71,26 @@ class RecordManager(ABC): ValueError: If the length of keys doesn't match the length of group_ids. """ + @abstractmethod + async def aupdate( + self, + keys: Sequence[str], + *, + group_ids: Optional[Sequence[Optional[str]]] = None, + time_at_least: Optional[float] = None, + ) -> None: + """Upsert records into the database. + + Args: + keys: A list of record keys to upsert. + group_ids: A list of group IDs corresponding to the keys. + time_at_least: if provided, updates should only happen if the + updated_at field is at least this time. + + Raises: + ValueError: If the length of keys doesn't match the length of group_ids. + """ + @abstractmethod def exists(self, keys: Sequence[str]) -> List[bool]: """Check if the provided keys exist in the database. @@ -67,6 +102,17 @@ class RecordManager(ABC): A list of boolean values indicating the existence of each key. """ + @abstractmethod + async def aexists(self, keys: Sequence[str]) -> List[bool]: + """Check if the provided keys exist in the database. + + Args: + keys: A list of keys to check. + + Returns: + A list of boolean values indicating the existence of each key. + """ + @abstractmethod def list_keys( self, @@ -88,6 +134,27 @@ class RecordManager(ABC): A list of keys for the matching records. """ + @abstractmethod + async def alist_keys( + self, + *, + before: Optional[float] = None, + after: Optional[float] = None, + group_ids: Optional[Sequence[str]] = None, + limit: Optional[int] = None, + ) -> List[str]: + """List records in the database based on the provided filters. + + Args: + before: Filter to list records updated before this time. + after: Filter to list records updated after this time. + group_ids: Filter to list records with specific group IDs. + limit: optional limit on the number of records to return. + + Returns: + A list of keys for the matching records. + """ + @abstractmethod def delete_keys(self, keys: Sequence[str]) -> None: """Delete specified records from the database. @@ -95,3 +162,11 @@ class RecordManager(ABC): Args: keys: A list of keys to delete. """ + + @abstractmethod + async def adelete_keys(self, keys: Sequence[str]) -> None: + """Delete specified records from the database. + + Args: + keys: A list of keys to delete. 
+ """ diff --git a/libs/langchain/langchain/schema/vectorstore.py b/libs/langchain/langchain/schema/vectorstore.py index 861c04ffd9e..23db4473cef 100644 --- a/libs/langchain/langchain/schema/vectorstore.py +++ b/libs/langchain/langchain/schema/vectorstore.py @@ -80,6 +80,22 @@ class VectorStore(ABC): raise NotImplementedError("delete method must be implemented by subclass.") + async def adelete( + self, ids: Optional[List[str]] = None, **kwargs: Any + ) -> Optional[bool]: + """Delete by vector ID or other criteria. + + Args: + ids: List of ids to delete. + **kwargs: Other keyword arguments that subclasses might use. + + Returns: + Optional[bool]: True if deletion is successful, + False otherwise, None if not implemented. + """ + + raise NotImplementedError("delete method must be implemented by subclass.") + async def aadd_texts( self, texts: Iterable[str], diff --git a/libs/langchain/poetry.lock b/libs/langchain/poetry.lock index 2d4ff3c7fb9..4ebd7042c9e 100644 --- a/libs/langchain/poetry.lock +++ b/libs/langchain/poetry.lock @@ -1,4 +1,4 @@ -# This file is automatically @generated by Poetry 1.5.1 and should not be changed by hand. +# This file is automatically @generated by Poetry 1.6.1 and should not be changed by hand. [[package]] name = "absl-py" @@ -226,6 +226,21 @@ files = [ [package.dependencies] frozenlist = ">=1.1.0" +[[package]] +name = "aiosqlite" +version = "0.19.0" +description = "asyncio bridge to the standard sqlite3 module" +optional = true +python-versions = ">=3.7" +files = [ + {file = "aiosqlite-0.19.0-py3-none-any.whl", hash = "sha256:edba222e03453e094a3ce605db1b970c4b3376264e56f32e2a4959f948d66a96"}, + {file = "aiosqlite-0.19.0.tar.gz", hash = "sha256:95ee77b91c8d2808bd08a59fbebf66270e9090c3d92ffbf260dc0db0b979577d"}, +] + +[package.extras] +dev = ["aiounittest (==1.4.1)", "attribution (==1.6.2)", "black (==23.3.0)", "coverage[toml] (==7.2.3)", "flake8 (==5.0.4)", "flake8-bugbear (==23.3.12)", "flit (==3.7.1)", "mypy (==1.2.0)", "ufmt (==2.1.0)", "usort (==1.0.6)"] +docs = ["sphinx (==6.1.3)", "sphinx-mdinclude (==0.5.3)"] + [[package]] name = "aleph-alpha-client" version = "2.17.0" @@ -5768,12 +5783,11 @@ files = [ [package.dependencies] numpy = [ - {version = ">=1.21.0", markers = "python_version <= \"3.9\" and platform_system == \"Darwin\" and platform_machine == \"arm64\""}, - {version = ">=1.19.3", markers = "python_version >= \"3.6\" and platform_system == \"Linux\" and platform_machine == \"aarch64\" or python_version >= \"3.9\""}, - {version = ">=1.17.0", markers = "python_version >= \"3.7\""}, - {version = ">=1.17.3", markers = "python_version >= \"3.8\""}, - {version = ">=1.21.2", markers = "python_version >= \"3.10\""}, - {version = ">=1.21.4", markers = "python_version >= \"3.10\" and platform_system == \"Darwin\""}, + {version = ">=1.21.0", markers = "python_version <= \"3.9\" and platform_system == \"Darwin\" and platform_machine == \"arm64\" and python_version >= \"3.8\""}, + {version = ">=1.19.3", markers = "platform_system == \"Linux\" and platform_machine == \"aarch64\" and python_version >= \"3.8\" and python_version < \"3.10\" or python_version > \"3.9\" and python_version < \"3.10\" or python_version >= \"3.9\" and platform_system != \"Darwin\" and python_version < \"3.10\" or python_version >= \"3.9\" and platform_machine != \"arm64\" and python_version < \"3.10\""}, + {version = ">=1.17.3", markers = "(platform_system != \"Darwin\" and platform_system != \"Linux\") and python_version >= \"3.8\" and python_version < \"3.9\" or platform_system != 
\"Darwin\" and python_version >= \"3.8\" and python_version < \"3.9\" and platform_machine != \"aarch64\" or platform_machine != \"arm64\" and python_version >= \"3.8\" and python_version < \"3.9\" and platform_system != \"Linux\" or (platform_machine != \"arm64\" and platform_machine != \"aarch64\") and python_version >= \"3.8\" and python_version < \"3.9\""}, + {version = ">=1.21.4", markers = "python_version >= \"3.10\" and platform_system == \"Darwin\" and python_version < \"3.11\""}, + {version = ">=1.21.2", markers = "platform_system != \"Darwin\" and python_version >= \"3.10\" and python_version < \"3.11\""}, {version = ">=1.23.5", markers = "python_version >= \"3.11\""}, ] @@ -5961,7 +5975,7 @@ files = [ [package.dependencies] numpy = [ {version = ">=1.20.3", markers = "python_version < \"3.10\""}, - {version = ">=1.21.0", markers = "python_version >= \"3.10\""}, + {version = ">=1.21.0", markers = "python_version >= \"3.10\" and python_version < \"3.11\""}, {version = ">=1.23.2", markers = "python_version >= \"3.11\""}, ] python-dateutil = ">=2.8.2" @@ -8577,11 +8591,6 @@ files = [ {file = "scikit_learn-1.3.1-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f66eddfda9d45dd6cadcd706b65669ce1df84b8549875691b1f403730bdef217"}, {file = "scikit_learn-1.3.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c6448c37741145b241eeac617028ba6ec2119e1339b1385c9720dae31367f2be"}, {file = "scikit_learn-1.3.1-cp311-cp311-win_amd64.whl", hash = "sha256:c413c2c850241998168bbb3bd1bb59ff03b1195a53864f0b80ab092071af6028"}, - {file = "scikit_learn-1.3.1-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:ef540e09873e31569bc8b02c8a9f745ee04d8e1263255a15c9969f6f5caa627f"}, - {file = "scikit_learn-1.3.1-cp312-cp312-macosx_12_0_arm64.whl", hash = "sha256:9147a3a4df4d401e618713880be023e36109c85d8569b3bf5377e6cd3fecdeac"}, - {file = "scikit_learn-1.3.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d2cd3634695ad192bf71645702b3df498bd1e246fc2d529effdb45a06ab028b4"}, - {file = "scikit_learn-1.3.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:0c275a06c5190c5ce00af0acbb61c06374087949f643ef32d355ece12c4db043"}, - {file = "scikit_learn-1.3.1-cp312-cp312-win_amd64.whl", hash = "sha256:0e1aa8f206d0de814b81b41d60c1ce31f7f2c7354597af38fae46d9c47c45122"}, {file = "scikit_learn-1.3.1-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:52b77cc08bd555969ec5150788ed50276f5ef83abb72e6f469c5b91a0009bbca"}, {file = "scikit_learn-1.3.1-cp38-cp38-macosx_12_0_arm64.whl", hash = "sha256:a683394bc3f80b7c312c27f9b14ebea7766b1f0a34faf1a2e9158d80e860ec26"}, {file = "scikit_learn-1.3.1-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a15d964d9eb181c79c190d3dbc2fff7338786bf017e9039571418a1d53dab236"}, @@ -9029,7 +9038,7 @@ files = [ ] [package.dependencies] -greenlet = {version = "!=0.4.17", markers = "platform_machine == \"win32\" or platform_machine == \"WIN32\" or platform_machine == \"AMD64\" or platform_machine == \"amd64\" or platform_machine == \"x86_64\" or platform_machine == \"ppc64le\" or platform_machine == \"aarch64\""} +greenlet = {version = "!=0.4.17", markers = "platform_machine == \"aarch64\" or platform_machine == \"ppc64le\" or platform_machine == \"x86_64\" or platform_machine == \"amd64\" or platform_machine == \"AMD64\" or platform_machine == \"win32\" or platform_machine == \"WIN32\""} typing-extensions = ">=4.2.0" [package.extras] @@ -10854,7 +10863,7 @@ cli = ["typer"] 
cohere = ["cohere"] docarray = ["docarray"] embeddings = ["sentence-transformers"] -extended-testing = ["amazon-textract-caller", "anthropic", "arxiv", "assemblyai", "atlassian-python-api", "beautifulsoup4", "bibtexparser", "cassio", "chardet", "dashvector", "esprima", "faiss-cpu", "feedparser", "geopandas", "gitpython", "gql", "html2text", "jinja2", "jq", "lxml", "markdownify", "motor", "mwparserfromhell", "mwxml", "newspaper3k", "numexpr", "openai", "openai", "openapi-schema-pydantic", "pandas", "pdfminer-six", "pgvector", "psychicapi", "py-trello", "pymupdf", "pypdf", "pypdfium2", "pyspark", "rank-bm25", "rapidfuzz", "rapidocr-onnxruntime", "requests-toolbelt", "scikit-learn", "sqlite-vss", "streamlit", "sympy", "telethon", "timescale-vector", "tqdm", "xata", "xmltodict"] +extended-testing = ["aiosqlite", "amazon-textract-caller", "anthropic", "arxiv", "assemblyai", "atlassian-python-api", "beautifulsoup4", "bibtexparser", "cassio", "chardet", "dashvector", "esprima", "faiss-cpu", "feedparser", "geopandas", "gitpython", "gql", "html2text", "jinja2", "jq", "lxml", "markdownify", "motor", "mwparserfromhell", "mwxml", "newspaper3k", "numexpr", "openai", "openai", "openapi-schema-pydantic", "pandas", "pdfminer-six", "pgvector", "psychicapi", "py-trello", "pymupdf", "pypdf", "pypdfium2", "pyspark", "rank-bm25", "rapidfuzz", "rapidocr-onnxruntime", "requests-toolbelt", "scikit-learn", "sqlite-vss", "streamlit", "sympy", "telethon", "timescale-vector", "tqdm", "xata", "xmltodict"] javascript = ["esprima"] llms = ["clarifai", "cohere", "huggingface_hub", "manifest-ml", "nlpcloud", "openai", "openlm", "torch", "transformers"] openai = ["openai", "tiktoken"] @@ -10864,4 +10873,4 @@ text-helpers = ["chardet"] [metadata] lock-version = "2.0" python-versions = ">=3.8.1,<4.0" -content-hash = "7fbe9a5144717db54413735663870168b00e34deb4f37559e38d62843488adae" +content-hash = "8458ce2704b1fcba33f5b6bb8cb3d6fdf4d6b9a2006563f461a3edf7b8dd0d17" diff --git a/libs/langchain/pyproject.toml b/libs/langchain/pyproject.toml index e04aa25500c..30181a5cd71 100644 --- a/libs/langchain/pyproject.toml +++ b/libs/langchain/pyproject.toml @@ -137,6 +137,7 @@ jsonpatch = "^1.33" timescale-vector = {version = "^0.0.1", optional = true} typer = {version= "^0.9.0", optional = true} anthropic = {version = "^0.3.11", optional = true} +aiosqlite = {version = "^0.19.0", optional = true} [tool.poetry.group.test.dependencies] @@ -314,6 +315,7 @@ cli = [ # merge-conflicts extended_testing = [ "amazon-textract-caller", + "aiosqlite", "assemblyai", "beautifulsoup4", "bibtexparser", diff --git a/libs/langchain/tests/unit_tests/indexes/test_api.py b/libs/langchain/tests/unit_tests/indexes/test_api.py index 3104084d096..fa59c71b5bc 100644 --- a/libs/langchain/tests/unit_tests/indexes/test_api.py +++ b/libs/langchain/tests/unit_tests/indexes/test_api.py @@ -5,6 +5,7 @@ def test_all() -> None: """Use to catch obvious breaking changes.""" assert __all__ == sorted(__all__, key=str.lower) assert __all__ == [ + "aindex", "GraphIndexCreator", "index", "IndexingResult", diff --git a/libs/langchain/tests/unit_tests/indexes/test_indexing.py b/libs/langchain/tests/unit_tests/indexes/test_indexing.py index b3e3821b2f7..4e89e4adf92 100644 --- a/libs/langchain/tests/unit_tests/indexes/test_indexing.py +++ b/libs/langchain/tests/unit_tests/indexes/test_indexing.py @@ -1,14 +1,26 @@ from datetime import datetime -from typing import Any, Dict, Iterable, Iterator, List, Optional, Sequence, Type +from typing import ( + Any, + AsyncIterator, + Dict, + 
Iterable, + Iterator, + List, + Optional, + Sequence, + Type, +) from unittest.mock import patch import pytest +import pytest_asyncio from langchain.document_loaders.base import BaseLoader -from langchain.indexes import index +from langchain.embeddings.base import Embeddings +from langchain.indexes import aindex, index +from langchain.indexes._api import _abatch from langchain.indexes._sql_record_manager import SQLRecordManager from langchain.schema import Document -from langchain.schema.embeddings import Embeddings from langchain.schema.vectorstore import VST, VectorStore @@ -28,6 +40,19 @@ class ToyLoader(BaseLoader): """Load the documents from the source.""" return list(self.lazy_load()) + async def alazy_load( + self, + ) -> AsyncIterator[Document]: + async def async_generator() -> AsyncIterator[Document]: + for document in self.documents: + yield document + + return async_generator() + + async def aload(self) -> List[Document]: + """Load the documents from the source.""" + return [doc async for doc in await self.alazy_load()] + class InMemoryVectorStore(VectorStore): """In-memory implementation of VectorStore using a dictionary.""" @@ -42,6 +67,12 @@ class InMemoryVectorStore(VectorStore): for _id in ids: self.store.pop(_id, None) + async def adelete(self, ids: Optional[Sequence[str]] = None, **kwargs: Any) -> None: + """Delete the given documents from the store using their IDs.""" + if ids: + for _id in ids: + self.store.pop(_id, None) + def add_documents( # type: ignore self, documents: Sequence[Document], @@ -65,10 +96,33 @@ class InMemoryVectorStore(VectorStore): ) self.store[_id] = document + async def aadd_documents( + self, + documents: Sequence[Document], + *, + ids: Optional[Sequence[str]] = None, + **kwargs: Any, + ) -> List[str]: + if ids and len(ids) != len(documents): + raise ValueError( + f"Expected {len(ids)} ids, got {len(documents)} documents." + ) + + if not ids: + raise NotImplementedError("This is not implemented yet.") + + for _id, document in zip(ids, documents): + if _id in self.store: + raise ValueError( + f"Document with uid {_id} already exists in the store." 
+ ) + self.store[_id] = document + return list(ids) + def add_texts( self, texts: Iterable[str], - metadatas: Optional[List[dict]] = None, + metadatas: Optional[List[Dict[Any, Any]]] = None, **kwargs: Any, ) -> List[str]: """Add the given texts to the store (insert behavior).""" @@ -79,7 +133,7 @@ class InMemoryVectorStore(VectorStore): cls: Type[VST], texts: List[str], embedding: Embeddings, - metadatas: Optional[List[dict]] = None, + metadatas: Optional[List[Dict[Any, Any]]] = None, **kwargs: Any, ) -> VST: """Create a vector store from a list of texts.""" @@ -100,6 +154,19 @@ def record_manager() -> SQLRecordManager: return record_manager +@pytest_asyncio.fixture # type: ignore +@pytest.mark.requires("aiosqlite") +async def arecord_manager() -> SQLRecordManager: + """Timestamped set fixture.""" + record_manager = SQLRecordManager( + "kittens", + db_url="sqlite+aiosqlite:///:memory:", + async_mode=True, + ) + await record_manager.acreate_schema() + return record_manager + + @pytest.fixture def vector_store() -> InMemoryVectorStore: """Vector store fixture.""" @@ -140,6 +207,44 @@ def test_indexing_same_content( } +@pytest.mark.asyncio +@pytest.mark.requires("aiosqlite") +async def test_aindexing_same_content( + arecord_manager: SQLRecordManager, vector_store: InMemoryVectorStore +) -> None: + """Indexing some content to confirm it gets added only once.""" + loader = ToyLoader( + documents=[ + Document( + page_content="This is a test document.", + ), + Document( + page_content="This is another document.", + ), + ] + ) + + assert await aindex(await loader.alazy_load(), arecord_manager, vector_store) == { + "num_added": 2, + "num_deleted": 0, + "num_skipped": 0, + "num_updated": 0, + } + + assert len(list(vector_store.store)) == 2 + + for _ in range(2): + # Run the indexing again + assert await aindex( + await loader.alazy_load(), arecord_manager, vector_store + ) == { + "num_added": 0, + "num_deleted": 0, + "num_skipped": 2, + "num_updated": 0, + } + + def test_index_simple_delete_full( record_manager: SQLRecordManager, vector_store: InMemoryVectorStore ) -> None: @@ -215,6 +320,91 @@ def test_index_simple_delete_full( } +@pytest.mark.asyncio +@pytest.mark.requires("aiosqlite") +async def test_aindex_simple_delete_full( + arecord_manager: SQLRecordManager, vector_store: InMemoryVectorStore +) -> None: + """Indexing some content to confirm it gets added only once.""" + loader = ToyLoader( + documents=[ + Document( + page_content="This is a test document.", + ), + Document( + page_content="This is another document.", + ), + ] + ) + + with patch.object( + arecord_manager, "aget_time", return_value=datetime(2021, 1, 1).timestamp() + ): + assert await aindex( + await loader.alazy_load(), arecord_manager, vector_store, cleanup="full" + ) == { + "num_added": 2, + "num_deleted": 0, + "num_skipped": 0, + "num_updated": 0, + } + + with patch.object( + arecord_manager, "aget_time", return_value=datetime(2021, 1, 1).timestamp() + ): + assert await aindex( + await loader.alazy_load(), arecord_manager, vector_store, cleanup="full" + ) == { + "num_added": 0, + "num_deleted": 0, + "num_skipped": 2, + "num_updated": 0, + } + + loader = ToyLoader( + documents=[ + Document( + page_content="mutated document 1", + ), + Document( + page_content="This is another document.", # <-- Same as original + ), + ] + ) + + with patch.object( + arecord_manager, "aget_time", return_value=datetime(2021, 1, 2).timestamp() + ): + assert await aindex( + await loader.alazy_load(), arecord_manager, vector_store, cleanup="full" + 
) == { + "num_added": 1, + "num_deleted": 1, + "num_skipped": 1, + "num_updated": 0, + } + + doc_texts = set( + # Ignoring type since doc should be in the store and not a None + vector_store.store.get(uid).page_content # type: ignore + for uid in vector_store.store + ) + assert doc_texts == {"mutated document 1", "This is another document."} + + # Attempt to index again verify that nothing changes + with patch.object( + arecord_manager, "aget_time", return_value=datetime(2021, 1, 2).timestamp() + ): + assert await aindex( + await loader.alazy_load(), arecord_manager, vector_store, cleanup="full" + ) == { + "num_added": 0, + "num_deleted": 0, + "num_skipped": 2, + "num_updated": 0, + } + + def test_incremental_fails_with_bad_source_ids( record_manager: SQLRecordManager, vector_store: InMemoryVectorStore ) -> None: @@ -251,6 +441,49 @@ def test_incremental_fails_with_bad_source_ids( ) +@pytest.mark.asyncio +@pytest.mark.requires("aiosqlite") +async def test_aincremental_fails_with_bad_source_ids( + arecord_manager: SQLRecordManager, vector_store: InMemoryVectorStore +) -> None: + """Test indexing with incremental deletion strategy.""" + loader = ToyLoader( + documents=[ + Document( + page_content="This is a test document.", + metadata={"source": "1"}, + ), + Document( + page_content="This is another document.", + metadata={"source": "2"}, + ), + Document( + page_content="This is yet another document.", + metadata={"source": None}, + ), + ] + ) + + with pytest.raises(ValueError): + # Should raise an error because no source id function was specified + await aindex( + await loader.alazy_load(), + arecord_manager, + vector_store, + cleanup="incremental", + ) + + with pytest.raises(ValueError): + # Should raise an error because no source id function was specified + await aindex( + await loader.alazy_load(), + arecord_manager, + vector_store, + cleanup="incremental", + source_id_key="source", + ) + + def test_no_delete( record_manager: SQLRecordManager, vector_store: InMemoryVectorStore ) -> None: @@ -332,6 +565,89 @@ def test_no_delete( } +@pytest.mark.asyncio +@pytest.mark.requires("aiosqlite") +async def test_ano_delete( + arecord_manager: SQLRecordManager, vector_store: InMemoryVectorStore +) -> None: + """Test indexing without a deletion strategy.""" + loader = ToyLoader( + documents=[ + Document( + page_content="This is a test document.", + metadata={"source": "1"}, + ), + Document( + page_content="This is another document.", + metadata={"source": "2"}, + ), + ] + ) + + with patch.object( + arecord_manager, "aget_time", return_value=datetime(2021, 1, 2).timestamp() + ): + assert await aindex( + await loader.alazy_load(), + arecord_manager, + vector_store, + cleanup=None, + source_id_key="source", + ) == { + "num_added": 2, + "num_deleted": 0, + "num_skipped": 0, + "num_updated": 0, + } + + # If we add the same content twice it should be skipped + with patch.object( + arecord_manager, "aget_time", return_value=datetime(2021, 1, 2).timestamp() + ): + assert await aindex( + await loader.alazy_load(), + arecord_manager, + vector_store, + cleanup=None, + source_id_key="source", + ) == { + "num_added": 0, + "num_deleted": 0, + "num_skipped": 2, + "num_updated": 0, + } + + loader = ToyLoader( + documents=[ + Document( + page_content="mutated content", + metadata={"source": "1"}, + ), + Document( + page_content="This is another document.", + metadata={"source": "2"}, + ), + ] + ) + + # Should result in no updates or deletions! 
+ with patch.object( + arecord_manager, "aget_time", return_value=datetime(2021, 1, 2).timestamp() + ): + assert await aindex( + await loader.alazy_load(), + arecord_manager, + vector_store, + cleanup=None, + source_id_key="source", + ) == { + "num_added": 1, + "num_deleted": 0, + "num_skipped": 1, + "num_updated": 0, + } + + def test_incremental_delete( record_manager: SQLRecordManager, vector_store: InMemoryVectorStore ) -> None: @@ -436,6 +752,112 @@ def test_incremental_delete( } +@pytest.mark.asyncio +@pytest.mark.requires("aiosqlite") +async def test_aincremental_delete( + arecord_manager: SQLRecordManager, vector_store: InMemoryVectorStore +) -> None: + """Test indexing with incremental deletion strategy.""" + loader = ToyLoader( + documents=[ + Document( + page_content="This is a test document.", + metadata={"source": "1"}, + ), + Document( + page_content="This is another document.", + metadata={"source": "2"}, + ), + ] + ) + + with patch.object( + arecord_manager, "aget_time", return_value=datetime(2021, 1, 2).timestamp() + ): + assert await aindex( + await loader.alazy_load(), + arecord_manager, + vector_store, + cleanup="incremental", + source_id_key="source", + ) == { + "num_added": 2, + "num_deleted": 0, + "num_skipped": 0, + "num_updated": 0, + } + + doc_texts = set( + # Ignoring type since doc should be in the store and not a None + vector_store.store.get(uid).page_content # type: ignore + for uid in vector_store.store + ) + assert doc_texts == {"This is another document.", "This is a test document."} + + # Attempt to index again verify that nothing changes + with patch.object( + arecord_manager, "aget_time", return_value=datetime(2021, 1, 2).timestamp() + ): + assert await aindex( + await loader.alazy_load(), + arecord_manager, + vector_store, + cleanup="incremental", + source_id_key="source", + ) == { + "num_added": 0, + "num_deleted": 0, + "num_skipped": 2, + "num_updated": 0, + } + + # Create 2 documents from the same source all with mutated content + loader = ToyLoader( + documents=[ + Document( + page_content="mutated document 1", + metadata={"source": "1"}, + ), + Document( + page_content="mutated document 2", + metadata={"source": "1"}, + ), + Document( + page_content="This is another document.", # <-- Same as original + metadata={"source": "2"}, + ), + ] + ) + + # Attempt to index again verify that nothing changes + with patch.object( + arecord_manager, "aget_time", return_value=datetime(2021, 1, 3).timestamp() + ): + assert await aindex( + await loader.alazy_load(), + arecord_manager, + vector_store, + cleanup="incremental", + source_id_key="source", + ) == { + "num_added": 2, + "num_deleted": 1, + "num_skipped": 1, + "num_updated": 0, + } + + doc_texts = set( + # Ignoring type since doc should be in the store and not a None + vector_store.store.get(uid).page_content # type: ignore + for uid in vector_store.store + ) + assert doc_texts == { + "mutated document 1", + "mutated document 2", + "This is another document.", + } + + def test_indexing_with_no_docs( record_manager: SQLRecordManager, vector_store: VectorStore ) -> None: @@ -450,6 +872,24 @@ def test_indexing_with_no_docs( } +@pytest.mark.asyncio +@pytest.mark.requires("aiosqlite") +async def test_aindexing_with_no_docs( + arecord_manager: SQLRecordManager, vector_store: VectorStore +) -> None: + """Check edge case when loader returns no new docs.""" + loader = ToyLoader(documents=[]) + + assert await aindex( + await loader.alazy_load(), arecord_manager, vector_store, cleanup="full" + ) == { + "num_added": 0, 
+ "num_deleted": 0, + "num_skipped": 0, + "num_updated": 0, + } + + def test_deduplication( record_manager: SQLRecordManager, vector_store: VectorStore ) -> None: @@ -474,6 +914,32 @@ def test_deduplication( } +@pytest.mark.asyncio +@pytest.mark.requires("aiosqlite") +async def test_adeduplication( + arecord_manager: SQLRecordManager, vector_store: VectorStore +) -> None: + """Check edge case when loader returns no new docs.""" + docs = [ + Document( + page_content="This is a test document.", + metadata={"source": "1"}, + ), + Document( + page_content="This is a test document.", + metadata={"source": "1"}, + ), + ] + + # Should result in only a single document being added + assert await aindex(docs, arecord_manager, vector_store, cleanup="full") == { + "num_added": 1, + "num_deleted": 0, + "num_skipped": 0, + "num_updated": 0, + } + + def test_cleanup_with_different_batchsize( record_manager: SQLRecordManager, vector_store: VectorStore ) -> None: @@ -511,6 +977,45 @@ def test_cleanup_with_different_batchsize( } +@pytest.mark.asyncio +@pytest.mark.requires("aiosqlite") +async def test_async_cleanup_with_different_batchsize( + arecord_manager: SQLRecordManager, vector_store: InMemoryVectorStore +) -> None: + """Check that we can clean up with different batch size.""" + docs = [ + Document( + page_content="This is a test document.", + metadata={"source": str(d)}, + ) + for d in range(1000) + ] + + assert await aindex(docs, arecord_manager, vector_store, cleanup="full") == { + "num_added": 1000, + "num_deleted": 0, + "num_skipped": 0, + "num_updated": 0, + } + + docs = [ + Document( + page_content="Different doc", + metadata={"source": str(d)}, + ) + for d in range(1001) + ] + + assert await aindex( + docs, arecord_manager, vector_store, cleanup="full", cleanup_batch_size=17 + ) == { + "num_added": 1001, + "num_deleted": 1000, + "num_skipped": 0, + "num_updated": 0, + } + + def test_deduplication_v2( record_manager: SQLRecordManager, vector_store: VectorStore ) -> None: @@ -547,3 +1052,29 @@ def test_deduplication_v2( [document.page_content for document in vector_store.store.values()] ) assert contents == ["1", "2", "3"] + + +async def _to_async_iter(it: Iterable[Any]) -> AsyncIterator[Any]: + """Convert an iterable to an async iterator.""" + for i in it: + yield i + + +@pytest.mark.asyncio +async def test_abatch() -> None: + """Test the abatch function.""" + batches = _abatch(5, _to_async_iter(range(12))) + assert isinstance(batches, AsyncIterator) + assert [batch async for batch in batches] == [ + [0, 1, 2, 3, 4], + [5, 6, 7, 8, 9], + [10, 11], + ] + + batches = _abatch(1, _to_async_iter(range(3))) + assert isinstance(batches, AsyncIterator) + assert [batch async for batch in batches] == [[0], [1], [2]] + + batches = _abatch(2, _to_async_iter(range(5))) + assert isinstance(batches, AsyncIterator) + assert [batch async for batch in batches] == [[0, 1], [2, 3], [4]] diff --git a/libs/langchain/tests/unit_tests/indexes/test_sql_record_manager.py b/libs/langchain/tests/unit_tests/indexes/test_sql_record_manager.py index d7d95212ab5..966ad8b0e5f 100644 --- a/libs/langchain/tests/unit_tests/indexes/test_sql_record_manager.py +++ b/libs/langchain/tests/unit_tests/indexes/test_sql_record_manager.py @@ -2,6 +2,8 @@ from datetime import datetime from unittest.mock import patch import pytest +import pytest_asyncio +from sqlalchemy import select from langchain.indexes._sql_record_manager import SQLRecordManager, UpsertionRecord @@ -15,6 +17,20 @@ def manager() -> SQLRecordManager: return record_manager 
+@pytest_asyncio.fixture # type: ignore +@pytest.mark.requires("aiosqlite") +async def amanager() -> SQLRecordManager: + """Initialize the test database and yield the TimestampedSet instance.""" + # Initialize and yield the TimestampedSet instance + record_manager = SQLRecordManager( + "kittens", + db_url="sqlite+aiosqlite:///:memory:", + async_mode=True, + ) + await record_manager.acreate_schema() + return record_manager + + def test_update(manager: SQLRecordManager) -> None: """Test updating records in the database.""" # no keys should be present in the set @@ -28,6 +44,21 @@ def test_update(manager: SQLRecordManager) -> None: assert read_keys == ["key1", "key2", "key3"] +@pytest.mark.asyncio +@pytest.mark.requires("aiosqlite") +async def test_aupdate(amanager: SQLRecordManager) -> None: + """Test updating records in the database.""" + # no keys should be present in the set + read_keys = await amanager.alist_keys() + assert read_keys == [] + # Insert records + keys = ["key1", "key2", "key3"] + await amanager.aupdate(keys) + # Retrieve the records + read_keys = await amanager.alist_keys() + assert read_keys == ["key1", "key2", "key3"] + + def test_update_timestamp(manager: SQLRecordManager) -> None: """Test updating records in the database.""" # no keys should be present in the set @@ -119,6 +150,117 @@ def test_update_timestamp(manager: SQLRecordManager) -> None: ] +@pytest.mark.asyncio +@pytest.mark.requires("aiosqlite") +async def test_aupdate_timestamp(amanager: SQLRecordManager) -> None: + """Test updating records in the database.""" + # no keys should be present in the set + with patch.object( + amanager, "aget_time", return_value=datetime(2021, 1, 2).timestamp() + ): + await amanager.aupdate(["key1"]) + + async with amanager._amake_session() as session: + records = ( + ( + await session.execute( + select(UpsertionRecord).filter( + UpsertionRecord.namespace == amanager.namespace + ) + ) + ) + .scalars() + .all() + ) + + assert [ + { + "key": record.key, + "namespace": record.namespace, + "updated_at": record.updated_at, + "group_id": record.group_id, + } + for record in records + ] == [ + { + "group_id": None, + "key": "key1", + "namespace": "kittens", + "updated_at": datetime(2021, 1, 2, 0, 0).timestamp(), + } + ] + + with patch.object( + amanager, "aget_time", return_value=datetime(2023, 1, 2).timestamp() + ): + await amanager.aupdate(["key1"]) + + async with amanager._amake_session() as session: + records = ( + ( + await session.execute( + select(UpsertionRecord).filter( + UpsertionRecord.namespace == amanager.namespace + ) + ) + ) + .scalars() + .all() + ) + + assert [ + { + "key": record.key, + "namespace": record.namespace, + "updated_at": record.updated_at, + "group_id": record.group_id, + } + for record in records + ] == [ + { + "group_id": None, + "key": "key1", + "namespace": "kittens", + "updated_at": datetime(2023, 1, 2, 0, 0).timestamp(), + } + ] + + with patch.object( + amanager, "aget_time", return_value=datetime(2023, 2, 2).timestamp() + ): + await amanager.aupdate(["key1"], group_ids=["group1"]) + + async with amanager._amake_session() as session: + records = ( + ( + await session.execute( + select(UpsertionRecord).filter( + UpsertionRecord.namespace == amanager.namespace + ) + ) + ) + .scalars() + .all() + ) + + assert [ + { + "key": record.key, + "namespace": record.namespace, + "updated_at": record.updated_at, + "group_id": record.group_id, + } + for record in records + ] == [ + { + "group_id": "group1", + "key": "key1", + "namespace": "kittens", + "updated_at": 
datetime(2023, 2, 2, 0, 0).timestamp(), + } + ] + + def test_update_with_group_ids(manager: SQLRecordManager) -> None: """Test updating records in the database.""" # no keys should be present in the set @@ -132,6 +274,21 @@ def test_update_with_group_ids(manager: SQLRecordManager) -> None: assert read_keys == ["key1", "key2", "key3"] +@pytest.mark.asyncio +@pytest.mark.requires("aiosqlite") +async def test_aupdate_with_group_ids(amanager: SQLRecordManager) -> None: + """Test updating records in the database.""" + # no keys should be present in the set + read_keys = await amanager.alist_keys() + assert read_keys == [] + # Insert records + keys = ["key1", "key2", "key3"] + await amanager.aupdate(keys) + # Retrieve the records + read_keys = await amanager.alist_keys() + assert read_keys == ["key1", "key2", "key3"] + + def test_exists(manager: SQLRecordManager) -> None: """Test checking if keys exist in the database.""" # Insert records @@ -147,6 +304,23 @@ def test_exists(manager: SQLRecordManager) -> None: assert exists == [True, False] +@pytest.mark.asyncio +@pytest.mark.requires("aiosqlite") +async def test_aexists(amanager: SQLRecordManager) -> None: + """Test checking if keys exist in the database.""" + # Insert records + keys = ["key1", "key2", "key3"] + await amanager.aupdate(keys) + # Check if the keys exist in the database + exists = await amanager.aexists(keys) + assert len(exists) == len(keys) + assert exists == [True, True, True] + + exists = await amanager.aexists(["key1", "key4"]) + assert len(exists) == 2 + assert exists == [True, False] + + def test_list_keys(manager: SQLRecordManager) -> None: """Test listing keys based on the provided date range.""" # Insert records @@ -234,6 +408,98 @@ def test_list_keys(manager: SQLRecordManager) -> None: ) == ["key4"] +@pytest.mark.asyncio +@pytest.mark.requires("aiosqlite") +async def test_alist_keys(amanager: SQLRecordManager) -> None: + """Test listing keys based on the provided date range.""" + # Insert records + assert await amanager.alist_keys() == [] + async with amanager._amake_session() as session: + # Add some keys with explicit updated_ats + session.add( + UpsertionRecord( + key="key1", + updated_at=datetime(2021, 1, 1).timestamp(), + namespace="kittens", + ) + ) + session.add( + UpsertionRecord( + key="key2", + updated_at=datetime(2022, 1, 1).timestamp(), + namespace="kittens", + ) + ) + session.add( + UpsertionRecord( + key="key3", + updated_at=datetime(2023, 1, 1).timestamp(), + namespace="kittens", + ) + ) + session.add( + UpsertionRecord( + key="key4", + group_id="group1", + updated_at=datetime(2024, 1, 1).timestamp(), + namespace="kittens", + ) + ) + # Insert keys from a different namespace, these should not be visible! 
+ session.add( + UpsertionRecord( + key="key1", + updated_at=datetime(2021, 1, 1).timestamp(), + namespace="puppies", + ) + ) + session.add( + UpsertionRecord( + key="key5", + updated_at=datetime(2021, 1, 1).timestamp(), + namespace="puppies", + ) + ) + await session.commit() + + # Retrieve all keys + assert await amanager.alist_keys() == ["key1", "key2", "key3", "key4"] + + # Retrieve keys updated after a certain date + assert await amanager.alist_keys(after=datetime(2022, 2, 1).timestamp()) == [ + "key3", + "key4", + ] + + # Retrieve keys updated after a certain date + assert await amanager.alist_keys(before=datetime(2022, 2, 1).timestamp()) == [ + "key1", + "key2", + ] + + # Retrieve keys updated after a certain date + assert await amanager.alist_keys(before=datetime(2019, 2, 1).timestamp()) == [] + + # Retrieve keys in a time range + assert await amanager.alist_keys( + before=datetime(2022, 2, 1).timestamp(), + after=datetime(2021, 11, 1).timestamp(), + ) == ["key2"] + + assert await amanager.alist_keys(group_ids=["group1", "group2"]) == ["key4"] + + # Test multiple filters + assert ( + await amanager.alist_keys( + group_ids=["group1", "group2"], before=datetime(2019, 1, 1).timestamp() + ) + == [] + ) + assert await amanager.alist_keys( + group_ids=["group1", "group2"], after=datetime(2019, 1, 1).timestamp() + ) == ["key4"] + + def test_namespace_is_used(manager: SQLRecordManager) -> None: """Verify that namespace is taken into account for all operations.""" assert manager.namespace == "kittens" @@ -261,6 +527,35 @@ def test_namespace_is_used(manager: SQLRecordManager) -> None: ] +@pytest.mark.asyncio +@pytest.mark.requires("aiosqlite") +async def test_anamespace_is_used(amanager: SQLRecordManager) -> None: + """Verify that namespace is taken into account for all operations.""" + assert amanager.namespace == "kittens" + async with amanager._amake_session() as session: + # Add some keys with explicit updated_ats + session.add(UpsertionRecord(key="key1", namespace="kittens")) + session.add(UpsertionRecord(key="key2", namespace="kittens")) + session.add(UpsertionRecord(key="key1", namespace="puppies")) + session.add(UpsertionRecord(key="key3", namespace="puppies")) + await session.commit() + + assert await amanager.alist_keys() == ["key1", "key2"] + await amanager.adelete_keys(["key1"]) + assert await amanager.alist_keys() == ["key2"] + await amanager.aupdate(["key3"], group_ids=["group3"]) + + async with amanager._amake_session() as session: + results = (await session.execute(select(UpsertionRecord))).scalars().all() + + assert sorted([(r.namespace, r.key, r.group_id) for r in results]) == [ + ("kittens", "key2", None), + ("kittens", "key3", "group3"), + ("puppies", "key1", None), + ("puppies", "key3", None), + ] + + def test_delete_keys(manager: SQLRecordManager) -> None: """Test deleting keys from the database.""" # Insert records @@ -274,3 +569,20 @@ def test_delete_keys(manager: SQLRecordManager) -> None: # Check if the deleted keys are no longer in the database remaining_keys = manager.list_keys() assert remaining_keys == ["key3"] + + +@pytest.mark.asyncio +@pytest.mark.requires("aiosqlite") +async def test_adelete_keys(amanager: SQLRecordManager) -> None: + """Test deleting keys from the database.""" + # Insert records + keys = ["key1", "key2", "key3"] + await amanager.aupdate(keys) + + # Delete some keys + keys_to_delete = ["key1", "key2"] + await amanager.adelete_keys(keys_to_delete) + + # Check if the deleted keys are no longer in the database + remaining_keys = await 
amanager.alist_keys()
+    assert remaining_keys == ["key3"]
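Finally, a hedged end-to-end sketch of the new async path: an async generator of documents fed directly into `aindex` with `cleanup="full"`. `vector_store` is again a stand-in for any store that overrides `adelete` and supports `aadd_documents`, and the file-backed db URL is illustrative:

```python
from typing import AsyncIterator

from langchain.indexes import SQLRecordManager, aindex
from langchain.schema import Document


async def load_documents() -> AsyncIterator[Document]:
    # Stand-in for a real async loader. With cleanup="full" the iterator must
    # yield the *entire* dataset, otherwise missing documents are treated as
    # deleted and cleaned up after indexing finishes.
    for source in ("a.txt", "b.txt"):
        yield Document(
            page_content=f"contents of {source}", metadata={"source": source}
        )


async def main(vector_store) -> None:
    record_manager = SQLRecordManager(
        "demo/full", db_url="sqlite+aiosqlite:///demo_index.db", async_mode=True
    )
    await record_manager.acreate_schema()

    result = await aindex(
        load_documents(),
        record_manager,
        vector_store,
        cleanup="full",
        batch_size=100,
    )
    # Re-running main() with an unchanged dataset should report everything as
    # skipped and nothing added or deleted.
    print(result)


# e.g. asyncio.run(main(some_async_capable_vector_store))
```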