Mirror of https://github.com/hwchase17/langchain.git (synced 2025-06-01 04:29:09 +00:00)

Add async sql record manager and async indexing API (#10726)

- **Description:** Add support for a SQLRecordManager in async environments. It includes the creation of a `RecorManagerAsync` abstract class.
- **Issue:** None
- **Dependencies:** Optional `aiosqlite`.
- **Tag maintainer:** @nfcampos
- **Twitter handle:** @jvelezmagic

Co-authored-by: Eugene Yurtsev <eyurtsev@gmail.com>

This commit is contained in:
parent 57ade13b2b
commit a1c7532298
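For context, a minimal end-to-end sketch of the API this commit introduces (illustrative only, not part of the diff): it assumes the optional aiosqlite driver is installed, the file name is arbitrary, and `vector_store` is any VectorStore whose `aadd_documents` and `adelete` methods are actually implemented.

from langchain.indexes import SQLRecordManager, aindex
from langchain.schema import Document


async def reindex(vector_store) -> None:
    # Record manager backed by an async SQLite engine (new in this commit).
    record_manager = SQLRecordManager(
        "my_docs",
        db_url="sqlite+aiosqlite:///record_manager_cache.sql",
        async_mode=True,
    )
    await record_manager.acreate_schema()

    docs = [Document(page_content="hello", metadata={"source": "greeting.txt"})]
    result = await aindex(
        docs,
        record_manager,
        vector_store,
        cleanup="full",
        source_id_key="source",
    )
    # result is an IndexingResult dict, e.g.
    # {"num_added": 1, "num_updated": 0, "num_skipped": 0, "num_deleted": 0}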
@ -13,13 +13,14 @@ Importantly, this keeps on working even if the content being written is derived
via a set of transformations from some source content (e.g., indexing children
documents that were derived from parent documents by chunking.)
"""
from langchain.indexes._api import IndexingResult, index
from langchain.indexes._api import IndexingResult, aindex, index
from langchain.indexes._sql_record_manager import SQLRecordManager
from langchain.indexes.graph import GraphIndexCreator
from langchain.indexes.vectorstore import VectorstoreIndexCreator

__all__ = [
    # Keep sorted
    "aindex",
    "GraphIndexCreator",
    "index",
    "IndexingResult",
@ -7,6 +7,8 @@ import uuid
|
||||
from itertools import islice
|
||||
from typing import (
|
||||
Any,
|
||||
AsyncIterable,
|
||||
AsyncIterator,
|
||||
Callable,
|
||||
Dict,
|
||||
Iterable,
|
||||
@ -15,6 +17,7 @@ from typing import (
|
||||
Literal,
|
||||
Optional,
|
||||
Sequence,
|
||||
Set,
|
||||
TypedDict,
|
||||
TypeVar,
|
||||
Union,
|
||||
@ -36,7 +39,7 @@ def _hash_string_to_uuid(input_string: str) -> uuid.UUID:
|
||||
return uuid.uuid5(NAMESPACE_UUID, hash_value)
|
||||
|
||||
|
||||
def _hash_nested_dict_to_uuid(data: dict) -> uuid.UUID:
|
||||
def _hash_nested_dict_to_uuid(data: dict[Any, Any]) -> uuid.UUID:
|
||||
"""Hashes a nested dictionary and returns the corresponding UUID."""
|
||||
serialized_data = json.dumps(data, sort_keys=True)
|
||||
hash_value = hashlib.sha1(serialized_data.encode("utf-8")).hexdigest()
|
||||
@ -118,6 +121,21 @@ def _batch(size: int, iterable: Iterable[T]) -> Iterator[List[T]]:
|
||||
yield chunk
|
||||
|
||||
|
||||
async def _abatch(size: int, iterable: AsyncIterable[T]) -> AsyncIterator[List[T]]:
|
||||
"""Utility batching function."""
|
||||
batch: List[T] = []
|
||||
async for element in iterable:
|
||||
if len(batch) < size:
|
||||
batch.append(element)
|
||||
|
||||
if len(batch) >= size:
|
||||
yield batch
|
||||
batch = []
|
||||
|
||||
if batch:
|
||||
yield batch
|
||||
|
||||
|
||||
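For reference, a small sketch of how the new `_abatch` helper behaves (it mirrors the unit test added later in this commit; the numbers are illustrative):

async def _demo_abatch() -> None:
    async def numbers():
        for i in range(12):
            yield i

    # _abatch yields lists of at most `size` items drawn from an async iterable.
    assert [batch async for batch in _abatch(5, numbers())] == [
        [0, 1, 2, 3, 4],
        [5, 6, 7, 8, 9],
        [10, 11],
    ]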
def _get_source_id_assigner(
|
||||
source_id_key: Union[str, Callable[[Document], str], None],
|
||||
) -> Callable[[Document], Union[str, None]]:
|
||||
@ -139,7 +157,7 @@ def _deduplicate_in_order(
|
||||
hashed_documents: Iterable[_HashedDocument],
|
||||
) -> Iterator[_HashedDocument]:
|
||||
"""Deduplicate a list of hashed documents while preserving order."""
|
||||
seen = set()
|
||||
seen: Set[str] = set()
|
||||
|
||||
for hashed_doc in hashed_documents:
|
||||
if hashed_doc.hash_ not in seen:
|
||||
@ -346,3 +364,203 @@ def index(
|
||||
"num_skipped": num_skipped,
|
||||
"num_deleted": num_deleted,
|
||||
}
|
||||
|
||||
|
||||
# Define an asynchronous generator function
|
||||
async def _to_async_iterator(iterator: Iterable[T]) -> AsyncIterator[T]:
|
||||
"""Convert an iterable to an async iterator."""
|
||||
for item in iterator:
|
||||
yield item
|
||||
|
||||
|
||||
async def aindex(
|
||||
docs_source: Union[Iterable[Document], AsyncIterator[Document]],
|
||||
record_manager: RecordManager,
|
||||
vector_store: VectorStore,
|
||||
*,
|
||||
batch_size: int = 100,
|
||||
cleanup: Literal["incremental", "full", None] = None,
|
||||
source_id_key: Union[str, Callable[[Document], str], None] = None,
|
||||
cleanup_batch_size: int = 1_000,
|
||||
) -> IndexingResult:
|
||||
"""Index data from the loader into the vector store.
|
||||
|
||||
Indexing functionality uses a manager to keep track of which documents
|
||||
are in the vector store.
|
||||
|
||||
This allows us to keep track of which documents were updated, and which
|
||||
documents were deleted, which documents should be skipped.
|
||||
|
||||
For the time being, documents are indexed using their hashes, and users
|
||||
are not able to specify the uid of the document.
|
||||
|
||||
IMPORTANT:
|
||||
if auto_cleanup is set to True, the loader should be returning
|
||||
the entire dataset, and not just a subset of the dataset.
|
||||
Otherwise, the auto_cleanup will remove documents that it is not
|
||||
supposed to.
|
||||
|
||||
Args:
|
||||
docs_source: Data loader or iterable of documents to index.
|
||||
record_manager: Timestamped set to keep track of which documents were
|
||||
updated.
|
||||
vector_store: Vector store to index the documents into.
|
||||
batch_size: Batch size to use when indexing.
|
||||
cleanup: How to handle clean up of documents.
|
||||
- Incremental: Cleans up all documents that haven't been updated AND
|
||||
that are associated with source ids that were seen
|
||||
during indexing.
|
||||
Clean up is done continuously during indexing helping
|
||||
to minimize the probability of users seeing duplicated
|
||||
content.
|
||||
- Full: Delete all documents that haven to been returned by the loader.
|
||||
Clean up runs after all documents have been indexed.
|
||||
This means that users may see duplicated content during indexing.
|
||||
- None: Do not delete any documents.
|
||||
source_id_key: Optional key that helps identify the original source
|
||||
of the document.
|
||||
cleanup_batch_size: Batch size to use when cleaning up documents.
|
||||
|
||||
Returns:
|
||||
Indexing result which contains information about how many documents
|
||||
were added, updated, deleted, or skipped.
|
||||
"""
|
||||
|
||||
if cleanup not in {"incremental", "full", None}:
|
||||
raise ValueError(
|
||||
f"cleanup should be one of 'incremental', 'full' or None. "
|
||||
f"Got {cleanup}."
|
||||
)
|
||||
|
||||
if cleanup == "incremental" and source_id_key is None:
|
||||
raise ValueError("Source id key is required when cleanup mode is incremental.")
|
||||
|
||||
# Check that the Vectorstore has required methods implemented
|
||||
methods = ["adelete", "aadd_documents"]
|
||||
|
||||
for method in methods:
|
||||
if not hasattr(vector_store, method):
|
||||
raise ValueError(
|
||||
f"Vectorstore {vector_store} does not have required method {method}"
|
||||
)
|
||||
|
||||
if type(vector_store).adelete == VectorStore.adelete:
|
||||
# Checking if the vectorstore has overridden the default delete method
|
||||
# implementation which just raises a NotImplementedError
|
||||
raise ValueError("Vectorstore has not implemented the delete method")
|
||||
|
||||
if isinstance(docs_source, BaseLoader):
|
||||
raise NotImplementedError(
|
||||
"Not supported yet. Please pass an async iterator of documents."
|
||||
)
|
||||
async_doc_iterator: AsyncIterator[Document]
|
||||
|
||||
if hasattr(docs_source, "__aiter__"):
|
||||
async_doc_iterator = docs_source # type: ignore[assignment]
|
||||
else:
|
||||
async_doc_iterator = _to_async_iterator(docs_source)
|
||||
|
||||
source_id_assigner = _get_source_id_assigner(source_id_key)
|
||||
|
||||
# Mark when the update started.
|
||||
index_start_dt = await record_manager.aget_time()
|
||||
num_added = 0
|
||||
num_skipped = 0
|
||||
num_updated = 0
|
||||
num_deleted = 0
|
||||
|
||||
async for doc_batch in _abatch(batch_size, async_doc_iterator):
|
||||
hashed_docs = list(
|
||||
_deduplicate_in_order(
|
||||
[_HashedDocument.from_document(doc) for doc in doc_batch]
|
||||
)
|
||||
)
|
||||
|
||||
source_ids: Sequence[Optional[str]] = [
|
||||
source_id_assigner(doc) for doc in hashed_docs
|
||||
]
|
||||
|
||||
if cleanup == "incremental":
|
||||
# If the cleanup mode is incremental, source ids are required.
|
||||
for source_id, hashed_doc in zip(source_ids, hashed_docs):
|
||||
if source_id is None:
|
||||
raise ValueError(
|
||||
"Source ids are required when cleanup mode is incremental. "
|
||||
f"Document that starts with "
|
||||
f"content: {hashed_doc.page_content[:100]} was not assigned "
|
||||
f"as source id."
|
||||
)
|
||||
# source ids cannot be None after for loop above.
|
||||
source_ids = cast(Sequence[str], source_ids)
|
||||
|
||||
exists_batch = await record_manager.aexists([doc.uid for doc in hashed_docs])
|
||||
|
||||
# Filter out documents that already exist in the record store.
|
||||
uids: list[str] = []
|
||||
docs_to_index: list[Document] = []
|
||||
|
||||
for hashed_doc, doc_exists in zip(hashed_docs, exists_batch):
|
||||
if doc_exists:
|
||||
# Must be updated to refresh timestamp.
|
||||
await record_manager.aupdate(
|
||||
[hashed_doc.uid], time_at_least=index_start_dt
|
||||
)
|
||||
num_skipped += 1
|
||||
continue
|
||||
uids.append(hashed_doc.uid)
|
||||
docs_to_index.append(hashed_doc.to_document())
|
||||
|
||||
# Be pessimistic and assume that all vector store writes will fail.
|
||||
# First write to vector store
|
||||
if docs_to_index:
|
||||
await vector_store.aadd_documents(docs_to_index, ids=uids)
|
||||
num_added += len(docs_to_index)
|
||||
|
||||
# And only then update the record store.
|
||||
# Update ALL records, even if they already exist since we want to refresh
|
||||
# their timestamp.
|
||||
await record_manager.aupdate(
|
||||
[doc.uid for doc in hashed_docs],
|
||||
group_ids=source_ids,
|
||||
time_at_least=index_start_dt,
|
||||
)
|
||||
|
||||
# If source IDs are provided, we can do the deletion incrementally!
|
||||
|
||||
if cleanup == "incremental":
|
||||
# Get the uids of the documents that were not returned by the loader.
|
||||
|
||||
# mypy isn't good enough to determine that source ids cannot be None
|
||||
# here due to a check that's happening above, so we check again.
|
||||
for source_id in source_ids:
|
||||
if source_id is None:
|
||||
raise AssertionError("Source ids cannot be None here.")
|
||||
|
||||
_source_ids = cast(Sequence[str], source_ids)
|
||||
|
||||
uids_to_delete = await record_manager.alist_keys(
    group_ids=_source_ids, before=index_start_dt
)
if uids_to_delete:
    # First delete from vector store.
    await vector_store.adelete(uids_to_delete)
    # Then delete from record store.
    await record_manager.adelete_keys(uids_to_delete)
    num_deleted += len(uids_to_delete)
|
||||
|
||||
if cleanup == "full":
|
||||
while uids_to_delete := await record_manager.alist_keys(
|
||||
before=index_start_dt, limit=cleanup_batch_size
|
||||
):
|
||||
# First delete from record store.
|
||||
await vector_store.adelete(uids_to_delete)
|
||||
# Then delete from record manager.
|
||||
await record_manager.adelete_keys(uids_to_delete)
|
||||
num_deleted += len(uids_to_delete)
|
||||
|
||||
return {
|
||||
"num_added": num_added,
|
||||
"num_updated": num_updated,
|
||||
"num_skipped": num_skipped,
|
||||
"num_deleted": num_deleted,
|
||||
}
|
||||
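A short sketch of the incremental mode of the new `aindex` (illustrative; `record_manager` and `vector_store` are assumed to be an async-mode SQLRecordManager and a vector store with working async methods, and Document is the langchain schema class):

async def reindex_incrementally(record_manager, vector_store):
    docs = [Document(page_content="v2 of the doc", metadata={"source": "doc-1"})]
    result = await aindex(
        docs,
        record_manager,
        vector_store,
        cleanup="incremental",
        source_id_key="source",  # required for incremental cleanup
    )
    # Stale records whose group_id is "doc-1" are deleted during the run;
    # documents from sources not seen in this run are left untouched.
    return result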
|
@ -16,7 +16,7 @@ allow it to work with a variety of SQL as a backend.
|
||||
import contextlib
|
||||
import decimal
|
||||
import uuid
|
||||
from typing import Any, Dict, Generator, List, Optional, Sequence, Union
|
||||
from typing import Any, AsyncGenerator, Dict, Generator, List, Optional, Sequence, Union
|
||||
|
||||
from sqlalchemy import (
|
||||
URL,
|
||||
@ -28,8 +28,16 @@ from sqlalchemy import (
|
||||
UniqueConstraint,
|
||||
and_,
|
||||
create_engine,
|
||||
delete,
|
||||
select,
|
||||
text,
|
||||
)
|
||||
from sqlalchemy.ext.asyncio import (
|
||||
AsyncEngine,
|
||||
AsyncSession,
|
||||
async_sessionmaker,
|
||||
create_async_engine,
|
||||
)
|
||||
from sqlalchemy.ext.declarative import declarative_base
|
||||
from sqlalchemy.orm import Session, sessionmaker
|
||||
|
||||
@ -77,9 +85,10 @@ class SQLRecordManager(RecordManager):
|
||||
self,
|
||||
namespace: str,
|
||||
*,
|
||||
engine: Optional[Engine] = None,
|
||||
engine: Optional[Union[Engine, AsyncEngine]] = None,
|
||||
db_url: Union[None, str, URL] = None,
|
||||
engine_kwargs: Optional[Dict[str, Any]] = None,
|
||||
async_mode: bool = False,
|
||||
) -> None:
|
||||
"""Initialize the SQLRecordManager.
|
||||
|
||||
@ -95,6 +104,10 @@ class SQLRecordManager(RecordManager):
|
||||
an SQL Alchemy engine. Default is None.
|
||||
engine_kwargs: Additional keyword arguments
|
||||
to be passed when creating the engine. Default is an empty dictionary.
|
||||
async_mode: Whether to create an async engine.
|
||||
Driver should support async operations.
|
||||
It only applies if db_url is provided.
|
||||
Default is False.
|
||||
|
||||
Raises:
|
||||
ValueError: If both db_url and engine are provided or neither.
|
||||
@ -103,34 +116,71 @@ class SQLRecordManager(RecordManager):
|
||||
super().__init__(namespace=namespace)
|
||||
if db_url is None and engine is None:
|
||||
raise ValueError("Must specify either db_url or engine")
|
||||
|
||||
if db_url is not None and engine is not None:
|
||||
raise ValueError("Must specify either db_url or engine, not both")
|
||||
|
||||
_engine: Union[Engine, AsyncEngine]
|
||||
if db_url:
|
||||
_kwargs = engine_kwargs or {}
|
||||
_engine = create_engine(db_url, **_kwargs)
|
||||
if async_mode:
|
||||
_engine = create_async_engine(db_url, **(engine_kwargs or {}))
|
||||
else:
|
||||
_engine = create_engine(db_url, **(engine_kwargs or {}))
|
||||
elif engine:
|
||||
_engine = engine
|
||||
|
||||
else:
|
||||
raise AssertionError("Something went wrong with configuration of engine.")
|
||||
|
||||
_session_factory: Union[sessionmaker[Session], async_sessionmaker[AsyncSession]]
|
||||
if isinstance(_engine, AsyncEngine):
|
||||
_session_factory = async_sessionmaker(bind=_engine)
|
||||
else:
|
||||
_session_factory = sessionmaker(bind=_engine)
|
||||
|
||||
self.engine = _engine
|
||||
self.dialect = _engine.dialect.name
|
||||
self.session_factory = sessionmaker(bind=self.engine)
|
||||
self.session_factory = _session_factory
|
||||
|
||||
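Two equivalent ways to end up with an async-backed record manager after this change (a sketch; the aiosqlite URLs are illustrative): either pass a db_url with async_mode=True, or pass an AsyncEngine directly, in which case the async session factory is selected from the engine type.

from sqlalchemy.ext.asyncio import create_async_engine

rm_from_url = SQLRecordManager(
    "my_docs", db_url="sqlite+aiosqlite:///:memory:", async_mode=True
)

rm_from_engine = SQLRecordManager(
    "my_docs", engine=create_async_engine("sqlite+aiosqlite:///:memory:")
)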
def create_schema(self) -> None:
|
||||
"""Create the database schema."""
|
||||
if isinstance(self.engine, AsyncEngine):
|
||||
raise AssertionError("This method is not supported for async engines.")
|
||||
|
||||
Base.metadata.create_all(self.engine)
|
||||
|
||||
async def acreate_schema(self) -> None:
|
||||
"""Create the database schema."""
|
||||
|
||||
if not isinstance(self.engine, AsyncEngine):
|
||||
raise AssertionError("This method is not supported for sync engines.")
|
||||
|
||||
async with self.engine.begin() as session:
|
||||
await session.run_sync(Base.metadata.create_all)
|
||||
|
||||
@contextlib.contextmanager
|
||||
def _make_session(self) -> Generator[Session, None, None]:
|
||||
"""Create a session and close it after use."""
|
||||
|
||||
if isinstance(self.session_factory, async_sessionmaker):
|
||||
raise AssertionError("This method is not supported for async engines.")
|
||||
|
||||
session = self.session_factory()
|
||||
try:
|
||||
yield session
|
||||
finally:
|
||||
session.close()
|
||||
|
||||
@contextlib.asynccontextmanager
|
||||
async def _amake_session(self) -> AsyncGenerator[AsyncSession, None]:
|
||||
"""Create a session and close it after use."""
|
||||
|
||||
if not isinstance(self.session_factory, async_sessionmaker):
|
||||
raise AssertionError("This method is not supported for sync engines.")
|
||||
|
||||
async with self.session_factory() as session:
|
||||
yield session
|
||||
|
||||
def get_time(self) -> float:
|
||||
"""Get the current server time as a timestamp.
|
||||
|
||||
@ -161,6 +211,37 @@ class SQLRecordManager(RecordManager):
|
||||
raise AssertionError(f"Unexpected type for datetime: {type(dt)}")
|
||||
return dt
|
||||
|
||||
async def aget_time(self) -> float:
|
||||
"""Get the current server time as a timestamp.
|
||||
|
||||
Please note it's critical that time is obtained from the server since
|
||||
we want a monotonic clock.
|
||||
"""
|
||||
async with self._amake_session() as session:
|
||||
# * SQLite specific implementation, can be changed based on dialect.
|
||||
# * For SQLite, unlike unixepoch, this works with older versions of SQLite.
|
||||
# ----
|
||||
# julianday('now'): Julian day number for the current date and time.
|
||||
# The Julian day is a continuous count of days, starting from a
|
||||
# reference date (Julian day number 0).
|
||||
# 2440587.5 - constant represents the Julian day number for January 1, 1970
|
||||
# 86400.0 - constant represents the number of seconds
|
||||
# in a day (24 hours * 60 minutes * 60 seconds)
|
||||
if self.dialect == "sqlite":
|
||||
query = text("SELECT (julianday('now') - 2440587.5) * 86400.0;")
|
||||
elif self.dialect == "postgresql":
|
||||
query = text("SELECT EXTRACT (EPOCH FROM CURRENT_TIMESTAMP);")
|
||||
else:
|
||||
raise NotImplementedError(f"Not implemented for dialect {self.dialect}")
|
||||
|
||||
dt = (await session.execute(query)).scalar_one_or_none()
|
||||
|
||||
if isinstance(dt, decimal.Decimal):
|
||||
dt = float(dt)
|
||||
if not isinstance(dt, float):
|
||||
raise AssertionError(f"Unexpected type for datetime: {type(dt)}")
|
||||
return dt
|
||||
|
||||
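A quick check of the constants used in the SQLite branch above (illustrative arithmetic only):

# julianday('now') counts days since the Julian epoch; subtracting 2440587.5
# (the Julian day of 1970-01-01 00:00 UTC) and multiplying by 86400.0 seconds
# per day yields a Unix timestamp.
unix_ts = (2460000.0 - 2440587.5) * 86400.0
print(unix_ts)  # 1677240000.0, i.e. 2023-02-24 12:00:00 UTC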
def update(
|
||||
self,
|
||||
keys: Sequence[str],
|
||||
@ -236,6 +317,81 @@ class SQLRecordManager(RecordManager):
|
||||
session.execute(stmt)
|
||||
session.commit()
|
||||
|
||||
async def aupdate(
|
||||
self,
|
||||
keys: Sequence[str],
|
||||
*,
|
||||
group_ids: Optional[Sequence[Optional[str]]] = None,
|
||||
time_at_least: Optional[float] = None,
|
||||
) -> None:
|
||||
"""Upsert records into the SQLite database."""
|
||||
if group_ids is None:
|
||||
group_ids = [None] * len(keys)
|
||||
|
||||
if len(keys) != len(group_ids):
|
||||
raise ValueError(
|
||||
f"Number of keys ({len(keys)}) does not match number of "
|
||||
f"group_ids ({len(group_ids)})"
|
||||
)
|
||||
|
||||
# Get the current time from the server.
|
||||
# This makes an extra round trip to the server, should not be a big deal
|
||||
# if the batch size is large enough.
|
||||
# Getting the time here helps us compare it against the time_at_least
|
||||
# and raise an error if there is a time sync issue.
|
||||
# Here, we're just being extra careful to minimize the chance of
|
||||
# data loss due to incorrectly deleting records.
|
||||
update_time = await self.aget_time()
|
||||
|
||||
if time_at_least and update_time < time_at_least:
|
||||
# Safeguard against time sync issues
|
||||
raise AssertionError(f"Time sync issue: {update_time} < {time_at_least}")
|
||||
|
||||
records_to_upsert = [
|
||||
{
|
||||
"key": key,
|
||||
"namespace": self.namespace,
|
||||
"updated_at": update_time,
|
||||
"group_id": group_id,
|
||||
}
|
||||
for key, group_id in zip(keys, group_ids)
|
||||
]
|
||||
|
||||
async with self._amake_session() as session:
|
||||
if self.dialect == "sqlite":
|
||||
from sqlalchemy.dialects.sqlite import insert as sqlite_insert
|
||||
|
||||
# Note: uses SQLite insert to make on_conflict_do_update work.
|
||||
# This code needs to be generalized a bit to work with more dialects.
|
||||
insert_stmt = sqlite_insert(UpsertionRecord).values(records_to_upsert)
|
||||
stmt = insert_stmt.on_conflict_do_update( # type: ignore[attr-defined]
|
||||
[UpsertionRecord.key, UpsertionRecord.namespace],
|
||||
set_=dict(
|
||||
# attr-defined type ignore
|
||||
updated_at=insert_stmt.excluded.updated_at, # type: ignore
|
||||
group_id=insert_stmt.excluded.group_id, # type: ignore
|
||||
),
|
||||
)
|
||||
elif self.dialect == "postgresql":
|
||||
from sqlalchemy.dialects.postgresql import insert as pg_insert
|
||||
|
||||
# Note: uses PostgreSQL insert to make on_conflict_do_update work.
|
||||
# This code needs to be generalized a bit to work with more dialects.
|
||||
insert_stmt = pg_insert(UpsertionRecord).values(records_to_upsert)
|
||||
stmt = insert_stmt.on_conflict_do_update( # type: ignore[attr-defined]
|
||||
"uix_key_namespace", # Name of constraint
|
||||
set_=dict(
|
||||
# attr-defined type ignore
|
||||
updated_at=insert_stmt.excluded.updated_at, # type: ignore
|
||||
group_id=insert_stmt.excluded.group_id, # type: ignore
|
||||
),
|
||||
)
|
||||
else:
|
||||
raise NotImplementedError(f"Unsupported dialect {self.dialect}")
|
||||
|
||||
await session.execute(stmt)
|
||||
await session.commit()
|
||||
|
||||
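A minimal sketch of calling the new `aupdate`/`aexists` directly (key and group-id names are illustrative):

async def touch_records(rm: SQLRecordManager) -> None:
    t0 = await rm.aget_time()
    # group_ids line up positionally with keys; time_at_least guards against
    # the server clock moving backwards between calls.
    await rm.aupdate(
        ["doc-1#chunk-0", "doc-1#chunk-1", "doc-2#chunk-0"],
        group_ids=["doc-1", "doc-1", "doc-2"],
        time_at_least=t0,
    )
    assert await rm.aexists(["doc-1#chunk-0", "missing"]) == [True, False]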
def exists(self, keys: Sequence[str]) -> List[bool]:
|
||||
"""Check if the given keys exist in the SQLite database."""
|
||||
with self._make_session() as session:
|
||||
@ -253,6 +409,26 @@ class SQLRecordManager(RecordManager):
|
||||
found_keys = set(r.key for r in records)
|
||||
return [k in found_keys for k in keys]
|
||||
|
||||
async def aexists(self, keys: Sequence[str]) -> List[bool]:
|
||||
"""Check if the given keys exist in the SQLite database."""
|
||||
async with self._amake_session() as session:
|
||||
records = (
|
||||
(
|
||||
await session.execute(
|
||||
select(UpsertionRecord.key).where(
|
||||
and_(
|
||||
UpsertionRecord.key.in_(keys),
|
||||
UpsertionRecord.namespace == self.namespace,
|
||||
)
|
||||
)
|
||||
)
|
||||
)
|
||||
.scalars()
|
||||
.all()
|
||||
)
|
||||
found_keys = set(records)
|
||||
return [k in found_keys for k in keys]
|
||||
|
||||
def list_keys(
|
||||
self,
|
||||
*,
|
||||
@ -286,6 +462,39 @@ class SQLRecordManager(RecordManager):
|
||||
records = query.all() # type: ignore[attr-defined]
|
||||
return [r.key for r in records]
|
||||
|
||||
async def alist_keys(
|
||||
self,
|
||||
*,
|
||||
before: Optional[float] = None,
|
||||
after: Optional[float] = None,
|
||||
group_ids: Optional[Sequence[str]] = None,
|
||||
limit: Optional[int] = None,
|
||||
) -> List[str]:
|
||||
"""List records in the SQLite database based on the provided date range."""
|
||||
async with self._amake_session() as session:
|
||||
query = select(UpsertionRecord.key).filter(
|
||||
UpsertionRecord.namespace == self.namespace
|
||||
)
|
||||
|
||||
# mypy does not recognize .all() or .filter()
|
||||
if after:
|
||||
query = query.filter( # type: ignore[attr-defined]
|
||||
UpsertionRecord.updated_at > after
|
||||
)
|
||||
if before:
|
||||
query = query.filter( # type: ignore[attr-defined]
|
||||
UpsertionRecord.updated_at < before
|
||||
)
|
||||
if group_ids:
|
||||
query = query.filter( # type: ignore[attr-defined]
|
||||
UpsertionRecord.group_id.in_(group_ids)
|
||||
)
|
||||
|
||||
if limit:
|
||||
query = query.limit(limit) # type: ignore[attr-defined]
|
||||
records = (await session.execute(query)).scalars().all()
|
||||
return list(records)
|
||||
|
||||
def delete_keys(self, keys: Sequence[str]) -> None:
|
||||
"""Delete records from the SQLite database."""
|
||||
with self._make_session() as session:
|
||||
@ -297,3 +506,17 @@ class SQLRecordManager(RecordManager):
|
||||
)
|
||||
).delete() # type: ignore[attr-defined]
|
||||
session.commit()
|
||||
|
||||
async def adelete_keys(self, keys: Sequence[str]) -> None:
|
||||
"""Delete records from the SQLite database."""
|
||||
async with self._amake_session() as session:
|
||||
await session.execute(
|
||||
delete(UpsertionRecord).where(
|
||||
and_(
|
||||
UpsertionRecord.key.in_(keys),
|
||||
UpsertionRecord.namespace == self.namespace,
|
||||
)
|
||||
)
|
||||
)
|
||||
|
||||
await session.commit()
|
||||
|
@ -25,6 +25,10 @@ class RecordManager(ABC):
|
||||
def create_schema(self) -> None:
|
||||
"""Create the database schema for the record manager."""
|
||||
|
||||
@abstractmethod
|
||||
async def acreate_schema(self) -> None:
|
||||
"""Create the database schema for the record manager."""
|
||||
|
||||
@abstractmethod
|
||||
def get_time(self) -> float:
|
||||
"""Get the current server time as a high resolution timestamp!
|
||||
@ -36,6 +40,17 @@ class RecordManager(ABC):
|
||||
The current server time as a float timestamp.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
async def aget_time(self) -> float:
|
||||
"""Get the current server time as a high resolution timestamp!
|
||||
|
||||
It's important to get this from the server to ensure a monotonic clock,
|
||||
otherwise there may be data loss when cleaning up old documents!
|
||||
|
||||
Returns:
|
||||
The current server time as a float timestamp.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def update(
|
||||
self,
|
||||
@ -56,6 +71,26 @@ class RecordManager(ABC):
|
||||
ValueError: If the length of keys doesn't match the length of group_ids.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
async def aupdate(
|
||||
self,
|
||||
keys: Sequence[str],
|
||||
*,
|
||||
group_ids: Optional[Sequence[Optional[str]]] = None,
|
||||
time_at_least: Optional[float] = None,
|
||||
) -> None:
|
||||
"""Upsert records into the database.
|
||||
|
||||
Args:
|
||||
keys: A list of record keys to upsert.
|
||||
group_ids: A list of group IDs corresponding to the keys.
|
||||
time_at_least: if provided, updates should only happen if the
|
||||
updated_at field is at least this time.
|
||||
|
||||
Raises:
|
||||
ValueError: If the length of keys doesn't match the length of group_ids.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def exists(self, keys: Sequence[str]) -> List[bool]:
|
||||
"""Check if the provided keys exist in the database.
|
||||
@ -67,6 +102,17 @@ class RecordManager(ABC):
|
||||
A list of boolean values indicating the existence of each key.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
async def aexists(self, keys: Sequence[str]) -> List[bool]:
|
||||
"""Check if the provided keys exist in the database.
|
||||
|
||||
Args:
|
||||
keys: A list of keys to check.
|
||||
|
||||
Returns:
|
||||
A list of boolean values indicating the existence of each key.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def list_keys(
|
||||
self,
|
||||
@ -88,6 +134,27 @@ class RecordManager(ABC):
|
||||
A list of keys for the matching records.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
async def alist_keys(
|
||||
self,
|
||||
*,
|
||||
before: Optional[float] = None,
|
||||
after: Optional[float] = None,
|
||||
group_ids: Optional[Sequence[str]] = None,
|
||||
limit: Optional[int] = None,
|
||||
) -> List[str]:
|
||||
"""List records in the database based on the provided filters.
|
||||
|
||||
Args:
|
||||
before: Filter to list records updated before this time.
|
||||
after: Filter to list records updated after this time.
|
||||
group_ids: Filter to list records with specific group IDs.
|
||||
limit: optional limit on the number of records to return.
|
||||
|
||||
Returns:
|
||||
A list of keys for the matching records.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def delete_keys(self, keys: Sequence[str]) -> None:
|
||||
"""Delete specified records from the database.
|
||||
@ -95,3 +162,11 @@ class RecordManager(ABC):
|
||||
Args:
|
||||
keys: A list of keys to delete.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
async def adelete_keys(self, keys: Sequence[str]) -> None:
|
||||
"""Delete specified records from the database.
|
||||
|
||||
Args:
|
||||
keys: A list of keys to delete.
|
||||
"""
|
||||
|
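To make the expanded interface concrete, here is a minimal in-memory RecordManager satisfying both the sync and the new async abstract methods (an illustrative sketch, not part of this commit; real implementations should source timestamps from the database server, as the docstrings above note):

import time
from typing import Any, Dict, List, Optional, Sequence


class InMemoryRecordManager(RecordManager):
    """Toy record manager that keeps everything in a dict."""

    def __init__(self, namespace: str) -> None:
        super().__init__(namespace=namespace)
        self.records: Dict[str, Dict[str, Any]] = {}

    def create_schema(self) -> None:
        pass  # nothing to create for an in-memory store

    async def acreate_schema(self) -> None:
        self.create_schema()

    def get_time(self) -> float:
        return time.time()  # a real implementation should ask the DB server

    async def aget_time(self) -> float:
        return self.get_time()

    def update(
        self,
        keys: Sequence[str],
        *,
        group_ids: Optional[Sequence[Optional[str]]] = None,
        time_at_least: Optional[float] = None,
    ) -> None:
        group_ids = group_ids or [None] * len(keys)
        now = self.get_time()
        for key, group_id in zip(keys, group_ids):
            self.records[key] = {"group_id": group_id, "updated_at": now}

    async def aupdate(
        self,
        keys: Sequence[str],
        *,
        group_ids: Optional[Sequence[Optional[str]]] = None,
        time_at_least: Optional[float] = None,
    ) -> None:
        self.update(keys, group_ids=group_ids, time_at_least=time_at_least)

    def exists(self, keys: Sequence[str]) -> List[bool]:
        return [key in self.records for key in keys]

    async def aexists(self, keys: Sequence[str]) -> List[bool]:
        return self.exists(keys)

    def list_keys(
        self,
        *,
        before: Optional[float] = None,
        after: Optional[float] = None,
        group_ids: Optional[Sequence[str]] = None,
        limit: Optional[int] = None,
    ) -> List[str]:
        keys = [
            key
            for key, value in self.records.items()
            if (before is None or value["updated_at"] < before)
            and (after is None or value["updated_at"] > after)
            and (group_ids is None or value["group_id"] in group_ids)
        ]
        return keys[:limit] if limit else keys

    async def alist_keys(
        self,
        *,
        before: Optional[float] = None,
        after: Optional[float] = None,
        group_ids: Optional[Sequence[str]] = None,
        limit: Optional[int] = None,
    ) -> List[str]:
        return self.list_keys(
            before=before, after=after, group_ids=group_ids, limit=limit
        )

    def delete_keys(self, keys: Sequence[str]) -> None:
        for key in keys:
            self.records.pop(key, None)

    async def adelete_keys(self, keys: Sequence[str]) -> None:
        self.delete_keys(keys)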
@ -80,6 +80,22 @@ class VectorStore(ABC):
|
||||
|
||||
raise NotImplementedError("delete method must be implemented by subclass.")
|
||||
|
||||
async def adelete(
|
||||
self, ids: Optional[List[str]] = None, **kwargs: Any
|
||||
) -> Optional[bool]:
|
||||
"""Delete by vector ID or other criteria.
|
||||
|
||||
Args:
|
||||
ids: List of ids to delete.
|
||||
**kwargs: Other keyword arguments that subclasses might use.
|
||||
|
||||
Returns:
|
||||
Optional[bool]: True if deletion is successful,
|
||||
False otherwise, None if not implemented.
|
||||
"""
|
||||
|
||||
raise NotImplementedError("delete method must be implemented by subclass.")
|
||||
|
||||
async def aadd_texts(
|
||||
self,
|
||||
texts: Iterable[str],
|
||||
|
41 libs/langchain/poetry.lock (generated)
@ -1,4 +1,4 @@
|
||||
# This file is automatically @generated by Poetry 1.5.1 and should not be changed by hand.
|
||||
# This file is automatically @generated by Poetry 1.6.1 and should not be changed by hand.
|
||||
|
||||
[[package]]
|
||||
name = "absl-py"
|
||||
@ -226,6 +226,21 @@ files = [
|
||||
[package.dependencies]
|
||||
frozenlist = ">=1.1.0"
|
||||
|
||||
[[package]]
|
||||
name = "aiosqlite"
|
||||
version = "0.19.0"
|
||||
description = "asyncio bridge to the standard sqlite3 module"
|
||||
optional = true
|
||||
python-versions = ">=3.7"
|
||||
files = [
|
||||
{file = "aiosqlite-0.19.0-py3-none-any.whl", hash = "sha256:edba222e03453e094a3ce605db1b970c4b3376264e56f32e2a4959f948d66a96"},
|
||||
{file = "aiosqlite-0.19.0.tar.gz", hash = "sha256:95ee77b91c8d2808bd08a59fbebf66270e9090c3d92ffbf260dc0db0b979577d"},
|
||||
]
|
||||
|
||||
[package.extras]
|
||||
dev = ["aiounittest (==1.4.1)", "attribution (==1.6.2)", "black (==23.3.0)", "coverage[toml] (==7.2.3)", "flake8 (==5.0.4)", "flake8-bugbear (==23.3.12)", "flit (==3.7.1)", "mypy (==1.2.0)", "ufmt (==2.1.0)", "usort (==1.0.6)"]
|
||||
docs = ["sphinx (==6.1.3)", "sphinx-mdinclude (==0.5.3)"]
|
||||
|
||||
[[package]]
|
||||
name = "aleph-alpha-client"
|
||||
version = "2.17.0"
|
||||
@ -5768,12 +5783,11 @@ files = [
|
||||
|
||||
[package.dependencies]
|
||||
numpy = [
|
||||
{version = ">=1.21.0", markers = "python_version <= \"3.9\" and platform_system == \"Darwin\" and platform_machine == \"arm64\""},
|
||||
{version = ">=1.19.3", markers = "python_version >= \"3.6\" and platform_system == \"Linux\" and platform_machine == \"aarch64\" or python_version >= \"3.9\""},
|
||||
{version = ">=1.17.0", markers = "python_version >= \"3.7\""},
|
||||
{version = ">=1.17.3", markers = "python_version >= \"3.8\""},
|
||||
{version = ">=1.21.2", markers = "python_version >= \"3.10\""},
|
||||
{version = ">=1.21.4", markers = "python_version >= \"3.10\" and platform_system == \"Darwin\""},
|
||||
{version = ">=1.21.0", markers = "python_version <= \"3.9\" and platform_system == \"Darwin\" and platform_machine == \"arm64\" and python_version >= \"3.8\""},
|
||||
{version = ">=1.19.3", markers = "platform_system == \"Linux\" and platform_machine == \"aarch64\" and python_version >= \"3.8\" and python_version < \"3.10\" or python_version > \"3.9\" and python_version < \"3.10\" or python_version >= \"3.9\" and platform_system != \"Darwin\" and python_version < \"3.10\" or python_version >= \"3.9\" and platform_machine != \"arm64\" and python_version < \"3.10\""},
|
||||
{version = ">=1.17.3", markers = "(platform_system != \"Darwin\" and platform_system != \"Linux\") and python_version >= \"3.8\" and python_version < \"3.9\" or platform_system != \"Darwin\" and python_version >= \"3.8\" and python_version < \"3.9\" and platform_machine != \"aarch64\" or platform_machine != \"arm64\" and python_version >= \"3.8\" and python_version < \"3.9\" and platform_system != \"Linux\" or (platform_machine != \"arm64\" and platform_machine != \"aarch64\") and python_version >= \"3.8\" and python_version < \"3.9\""},
|
||||
{version = ">=1.21.4", markers = "python_version >= \"3.10\" and platform_system == \"Darwin\" and python_version < \"3.11\""},
|
||||
{version = ">=1.21.2", markers = "platform_system != \"Darwin\" and python_version >= \"3.10\" and python_version < \"3.11\""},
|
||||
{version = ">=1.23.5", markers = "python_version >= \"3.11\""},
|
||||
]
|
||||
|
||||
@ -5961,7 +5975,7 @@ files = [
|
||||
[package.dependencies]
|
||||
numpy = [
|
||||
{version = ">=1.20.3", markers = "python_version < \"3.10\""},
|
||||
{version = ">=1.21.0", markers = "python_version >= \"3.10\""},
|
||||
{version = ">=1.21.0", markers = "python_version >= \"3.10\" and python_version < \"3.11\""},
|
||||
{version = ">=1.23.2", markers = "python_version >= \"3.11\""},
|
||||
]
|
||||
python-dateutil = ">=2.8.2"
|
||||
@ -8577,11 +8591,6 @@ files = [
|
||||
{file = "scikit_learn-1.3.1-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f66eddfda9d45dd6cadcd706b65669ce1df84b8549875691b1f403730bdef217"},
|
||||
{file = "scikit_learn-1.3.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c6448c37741145b241eeac617028ba6ec2119e1339b1385c9720dae31367f2be"},
|
||||
{file = "scikit_learn-1.3.1-cp311-cp311-win_amd64.whl", hash = "sha256:c413c2c850241998168bbb3bd1bb59ff03b1195a53864f0b80ab092071af6028"},
|
||||
{file = "scikit_learn-1.3.1-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:ef540e09873e31569bc8b02c8a9f745ee04d8e1263255a15c9969f6f5caa627f"},
|
||||
{file = "scikit_learn-1.3.1-cp312-cp312-macosx_12_0_arm64.whl", hash = "sha256:9147a3a4df4d401e618713880be023e36109c85d8569b3bf5377e6cd3fecdeac"},
|
||||
{file = "scikit_learn-1.3.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d2cd3634695ad192bf71645702b3df498bd1e246fc2d529effdb45a06ab028b4"},
|
||||
{file = "scikit_learn-1.3.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:0c275a06c5190c5ce00af0acbb61c06374087949f643ef32d355ece12c4db043"},
|
||||
{file = "scikit_learn-1.3.1-cp312-cp312-win_amd64.whl", hash = "sha256:0e1aa8f206d0de814b81b41d60c1ce31f7f2c7354597af38fae46d9c47c45122"},
|
||||
{file = "scikit_learn-1.3.1-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:52b77cc08bd555969ec5150788ed50276f5ef83abb72e6f469c5b91a0009bbca"},
|
||||
{file = "scikit_learn-1.3.1-cp38-cp38-macosx_12_0_arm64.whl", hash = "sha256:a683394bc3f80b7c312c27f9b14ebea7766b1f0a34faf1a2e9158d80e860ec26"},
|
||||
{file = "scikit_learn-1.3.1-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a15d964d9eb181c79c190d3dbc2fff7338786bf017e9039571418a1d53dab236"},
|
||||
@ -9029,7 +9038,7 @@ files = [
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
greenlet = {version = "!=0.4.17", markers = "platform_machine == \"win32\" or platform_machine == \"WIN32\" or platform_machine == \"AMD64\" or platform_machine == \"amd64\" or platform_machine == \"x86_64\" or platform_machine == \"ppc64le\" or platform_machine == \"aarch64\""}
|
||||
greenlet = {version = "!=0.4.17", markers = "platform_machine == \"aarch64\" or platform_machine == \"ppc64le\" or platform_machine == \"x86_64\" or platform_machine == \"amd64\" or platform_machine == \"AMD64\" or platform_machine == \"win32\" or platform_machine == \"WIN32\""}
|
||||
typing-extensions = ">=4.2.0"
|
||||
|
||||
[package.extras]
|
||||
@ -10854,7 +10863,7 @@ cli = ["typer"]
|
||||
cohere = ["cohere"]
|
||||
docarray = ["docarray"]
|
||||
embeddings = ["sentence-transformers"]
|
||||
extended-testing = ["amazon-textract-caller", "anthropic", "arxiv", "assemblyai", "atlassian-python-api", "beautifulsoup4", "bibtexparser", "cassio", "chardet", "dashvector", "esprima", "faiss-cpu", "feedparser", "geopandas", "gitpython", "gql", "html2text", "jinja2", "jq", "lxml", "markdownify", "motor", "mwparserfromhell", "mwxml", "newspaper3k", "numexpr", "openai", "openai", "openapi-schema-pydantic", "pandas", "pdfminer-six", "pgvector", "psychicapi", "py-trello", "pymupdf", "pypdf", "pypdfium2", "pyspark", "rank-bm25", "rapidfuzz", "rapidocr-onnxruntime", "requests-toolbelt", "scikit-learn", "sqlite-vss", "streamlit", "sympy", "telethon", "timescale-vector", "tqdm", "xata", "xmltodict"]
|
||||
extended-testing = ["aiosqlite", "amazon-textract-caller", "anthropic", "arxiv", "assemblyai", "atlassian-python-api", "beautifulsoup4", "bibtexparser", "cassio", "chardet", "dashvector", "esprima", "faiss-cpu", "feedparser", "geopandas", "gitpython", "gql", "html2text", "jinja2", "jq", "lxml", "markdownify", "motor", "mwparserfromhell", "mwxml", "newspaper3k", "numexpr", "openai", "openai", "openapi-schema-pydantic", "pandas", "pdfminer-six", "pgvector", "psychicapi", "py-trello", "pymupdf", "pypdf", "pypdfium2", "pyspark", "rank-bm25", "rapidfuzz", "rapidocr-onnxruntime", "requests-toolbelt", "scikit-learn", "sqlite-vss", "streamlit", "sympy", "telethon", "timescale-vector", "tqdm", "xata", "xmltodict"]
|
||||
javascript = ["esprima"]
|
||||
llms = ["clarifai", "cohere", "huggingface_hub", "manifest-ml", "nlpcloud", "openai", "openlm", "torch", "transformers"]
|
||||
openai = ["openai", "tiktoken"]
|
||||
@ -10864,4 +10873,4 @@ text-helpers = ["chardet"]
|
||||
[metadata]
|
||||
lock-version = "2.0"
|
||||
python-versions = ">=3.8.1,<4.0"
|
||||
content-hash = "7fbe9a5144717db54413735663870168b00e34deb4f37559e38d62843488adae"
|
||||
content-hash = "8458ce2704b1fcba33f5b6bb8cb3d6fdf4d6b9a2006563f461a3edf7b8dd0d17"
|
||||
|
@ -137,6 +137,7 @@ jsonpatch = "^1.33"
|
||||
timescale-vector = {version = "^0.0.1", optional = true}
|
||||
typer = {version= "^0.9.0", optional = true}
|
||||
anthropic = {version = "^0.3.11", optional = true}
|
||||
aiosqlite = {version = "^0.19.0", optional = true}
|
||||
|
||||
|
||||
[tool.poetry.group.test.dependencies]
|
||||
@ -314,6 +315,7 @@ cli = [
|
||||
# merge-conflicts
|
||||
extended_testing = [
|
||||
"amazon-textract-caller",
|
||||
"aiosqlite",
|
||||
"assemblyai",
|
||||
"beautifulsoup4",
|
||||
"bibtexparser",
|
||||
|
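Since aiosqlite is only an optional dependency (added above to the extended_testing extra), a guarded import is one illustrative way for downstream code to fail early with a clear message; this is a sketch, not how langchain itself handles it:

# aiosqlite is needed at runtime for SQLRecordManager(async_mode=True) on SQLite.
try:
    import aiosqlite  # noqa: F401
except ImportError as exc:
    raise ImportError(
        "aiosqlite is required for the async SQLite record manager. "
        "Install it with `pip install aiosqlite`."
    ) from exc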
@ -5,6 +5,7 @@ def test_all() -> None:
|
||||
"""Use to catch obvious breaking changes."""
|
||||
assert __all__ == sorted(__all__, key=str.lower)
|
||||
assert __all__ == [
|
||||
"aindex",
|
||||
"GraphIndexCreator",
|
||||
"index",
|
||||
"IndexingResult",
|
||||
|
@ -1,14 +1,26 @@
|
||||
from datetime import datetime
|
||||
from typing import Any, Dict, Iterable, Iterator, List, Optional, Sequence, Type
|
||||
from typing import (
|
||||
Any,
|
||||
AsyncIterator,
|
||||
Dict,
|
||||
Iterable,
|
||||
Iterator,
|
||||
List,
|
||||
Optional,
|
||||
Sequence,
|
||||
Type,
|
||||
)
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
import pytest_asyncio
|
||||
|
||||
from langchain.document_loaders.base import BaseLoader
|
||||
from langchain.indexes import index
|
||||
from langchain.embeddings.base import Embeddings
|
||||
from langchain.indexes import aindex, index
|
||||
from langchain.indexes._api import _abatch
|
||||
from langchain.indexes._sql_record_manager import SQLRecordManager
|
||||
from langchain.schema import Document
|
||||
from langchain.schema.embeddings import Embeddings
|
||||
from langchain.schema.vectorstore import VST, VectorStore
|
||||
|
||||
|
||||
@ -28,6 +40,19 @@ class ToyLoader(BaseLoader):
|
||||
"""Load the documents from the source."""
|
||||
return list(self.lazy_load())
|
||||
|
||||
async def alazy_load(
|
||||
self,
|
||||
) -> AsyncIterator[Document]:
|
||||
async def async_generator() -> AsyncIterator[Document]:
|
||||
for document in self.documents:
|
||||
yield document
|
||||
|
||||
return async_generator()
|
||||
|
||||
async def aload(self) -> List[Document]:
|
||||
"""Load the documents from the source."""
|
||||
return [doc async for doc in await self.alazy_load()]
|
||||
|
||||
|
||||
class InMemoryVectorStore(VectorStore):
|
||||
"""In-memory implementation of VectorStore using a dictionary."""
|
||||
@ -42,6 +67,12 @@ class InMemoryVectorStore(VectorStore):
|
||||
for _id in ids:
|
||||
self.store.pop(_id, None)
|
||||
|
||||
async def adelete(self, ids: Optional[Sequence[str]] = None, **kwargs: Any) -> None:
|
||||
"""Delete the given documents from the store using their IDs."""
|
||||
if ids:
|
||||
for _id in ids:
|
||||
self.store.pop(_id, None)
|
||||
|
||||
def add_documents( # type: ignore
|
||||
self,
|
||||
documents: Sequence[Document],
|
||||
@ -65,10 +96,33 @@ class InMemoryVectorStore(VectorStore):
|
||||
)
|
||||
self.store[_id] = document
|
||||
|
||||
async def aadd_documents(
|
||||
self,
|
||||
documents: Sequence[Document],
|
||||
*,
|
||||
ids: Optional[Sequence[str]] = None,
|
||||
**kwargs: Any,
|
||||
) -> List[str]:
|
||||
if ids and len(ids) != len(documents):
|
||||
raise ValueError(
|
||||
f"Expected {len(ids)} ids, got {len(documents)} documents."
|
||||
)
|
||||
|
||||
if not ids:
|
||||
raise NotImplementedError("This is not implemented yet.")
|
||||
|
||||
for _id, document in zip(ids, documents):
|
||||
if _id in self.store:
|
||||
raise ValueError(
|
||||
f"Document with uid {_id} already exists in the store."
|
||||
)
|
||||
self.store[_id] = document
|
||||
return list(ids)
|
||||
|
||||
def add_texts(
|
||||
self,
|
||||
texts: Iterable[str],
|
||||
metadatas: Optional[List[dict]] = None,
|
||||
metadatas: Optional[List[Dict[Any, Any]]] = None,
|
||||
**kwargs: Any,
|
||||
) -> List[str]:
|
||||
"""Add the given texts to the store (insert behavior)."""
|
||||
@ -79,7 +133,7 @@ class InMemoryVectorStore(VectorStore):
|
||||
cls: Type[VST],
|
||||
texts: List[str],
|
||||
embedding: Embeddings,
|
||||
metadatas: Optional[List[dict]] = None,
|
||||
metadatas: Optional[List[Dict[Any, Any]]] = None,
|
||||
**kwargs: Any,
|
||||
) -> VST:
|
||||
"""Create a vector store from a list of texts."""
|
||||
@ -100,6 +154,19 @@ def record_manager() -> SQLRecordManager:
|
||||
return record_manager
|
||||
|
||||
|
||||
@pytest_asyncio.fixture # type: ignore
|
||||
@pytest.mark.requires("aiosqlite")
|
||||
async def arecord_manager() -> SQLRecordManager:
|
||||
"""Timestamped set fixture."""
|
||||
record_manager = SQLRecordManager(
|
||||
"kittens",
|
||||
db_url="sqlite+aiosqlite:///:memory:",
|
||||
async_mode=True,
|
||||
)
|
||||
await record_manager.acreate_schema()
|
||||
return record_manager
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def vector_store() -> InMemoryVectorStore:
|
||||
"""Vector store fixture."""
|
||||
@ -140,6 +207,44 @@ def test_indexing_same_content(
|
||||
}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.requires("aiosqlite")
|
||||
async def test_aindexing_same_content(
|
||||
arecord_manager: SQLRecordManager, vector_store: InMemoryVectorStore
|
||||
) -> None:
|
||||
"""Indexing some content to confirm it gets added only once."""
|
||||
loader = ToyLoader(
|
||||
documents=[
|
||||
Document(
|
||||
page_content="This is a test document.",
|
||||
),
|
||||
Document(
|
||||
page_content="This is another document.",
|
||||
),
|
||||
]
|
||||
)
|
||||
|
||||
assert await aindex(await loader.alazy_load(), arecord_manager, vector_store) == {
|
||||
"num_added": 2,
|
||||
"num_deleted": 0,
|
||||
"num_skipped": 0,
|
||||
"num_updated": 0,
|
||||
}
|
||||
|
||||
assert len(list(vector_store.store)) == 2
|
||||
|
||||
for _ in range(2):
|
||||
# Run the indexing again
|
||||
assert await aindex(
|
||||
await loader.alazy_load(), arecord_manager, vector_store
|
||||
) == {
|
||||
"num_added": 0,
|
||||
"num_deleted": 0,
|
||||
"num_skipped": 2,
|
||||
"num_updated": 0,
|
||||
}
|
||||
|
||||
|
||||
def test_index_simple_delete_full(
|
||||
record_manager: SQLRecordManager, vector_store: InMemoryVectorStore
|
||||
) -> None:
|
||||
@ -215,6 +320,91 @@ def test_index_simple_delete_full(
|
||||
}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.requires("aiosqlite")
|
||||
async def test_aindex_simple_delete_full(
|
||||
arecord_manager: SQLRecordManager, vector_store: InMemoryVectorStore
|
||||
) -> None:
|
||||
"""Indexing some content to confirm it gets added only once."""
|
||||
loader = ToyLoader(
|
||||
documents=[
|
||||
Document(
|
||||
page_content="This is a test document.",
|
||||
),
|
||||
Document(
|
||||
page_content="This is another document.",
|
||||
),
|
||||
]
|
||||
)
|
||||
|
||||
with patch.object(
|
||||
arecord_manager, "aget_time", return_value=datetime(2021, 1, 1).timestamp()
|
||||
):
|
||||
assert await aindex(
|
||||
await loader.alazy_load(), arecord_manager, vector_store, cleanup="full"
|
||||
) == {
|
||||
"num_added": 2,
|
||||
"num_deleted": 0,
|
||||
"num_skipped": 0,
|
||||
"num_updated": 0,
|
||||
}
|
||||
|
||||
with patch.object(
|
||||
arecord_manager, "aget_time", return_value=datetime(2021, 1, 1).timestamp()
|
||||
):
|
||||
assert await aindex(
|
||||
await loader.alazy_load(), arecord_manager, vector_store, cleanup="full"
|
||||
) == {
|
||||
"num_added": 0,
|
||||
"num_deleted": 0,
|
||||
"num_skipped": 2,
|
||||
"num_updated": 0,
|
||||
}
|
||||
|
||||
loader = ToyLoader(
|
||||
documents=[
|
||||
Document(
|
||||
page_content="mutated document 1",
|
||||
),
|
||||
Document(
|
||||
page_content="This is another document.", # <-- Same as original
|
||||
),
|
||||
]
|
||||
)
|
||||
|
||||
with patch.object(
|
||||
arecord_manager, "aget_time", return_value=datetime(2021, 1, 2).timestamp()
|
||||
):
|
||||
assert await aindex(
|
||||
await loader.alazy_load(), arecord_manager, vector_store, cleanup="full"
|
||||
) == {
|
||||
"num_added": 1,
|
||||
"num_deleted": 1,
|
||||
"num_skipped": 1,
|
||||
"num_updated": 0,
|
||||
}
|
||||
|
||||
doc_texts = set(
|
||||
# Ignoring type since doc should be in the store and not a None
|
||||
vector_store.store.get(uid).page_content # type: ignore
|
||||
for uid in vector_store.store
|
||||
)
|
||||
assert doc_texts == {"mutated document 1", "This is another document."}
|
||||
|
||||
# Attempt to index again and verify that nothing changes
|
||||
with patch.object(
|
||||
arecord_manager, "aget_time", return_value=datetime(2021, 1, 2).timestamp()
|
||||
):
|
||||
assert await aindex(
|
||||
await loader.alazy_load(), arecord_manager, vector_store, cleanup="full"
|
||||
) == {
|
||||
"num_added": 0,
|
||||
"num_deleted": 0,
|
||||
"num_skipped": 2,
|
||||
"num_updated": 0,
|
||||
}
|
||||
|
||||
|
||||
def test_incremental_fails_with_bad_source_ids(
|
||||
record_manager: SQLRecordManager, vector_store: InMemoryVectorStore
|
||||
) -> None:
|
||||
@ -251,6 +441,49 @@ def test_incremental_fails_with_bad_source_ids(
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.requires("aiosqlite")
|
||||
async def test_aincremental_fails_with_bad_source_ids(
|
||||
arecord_manager: SQLRecordManager, vector_store: InMemoryVectorStore
|
||||
) -> None:
|
||||
"""Test indexing with incremental deletion strategy."""
|
||||
loader = ToyLoader(
|
||||
documents=[
|
||||
Document(
|
||||
page_content="This is a test document.",
|
||||
metadata={"source": "1"},
|
||||
),
|
||||
Document(
|
||||
page_content="This is another document.",
|
||||
metadata={"source": "2"},
|
||||
),
|
||||
Document(
|
||||
page_content="This is yet another document.",
|
||||
metadata={"source": None},
|
||||
),
|
||||
]
|
||||
)
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
# Should raise an error because no source id function was specified
|
||||
await aindex(
|
||||
await loader.alazy_load(),
|
||||
arecord_manager,
|
||||
vector_store,
|
||||
cleanup="incremental",
|
||||
)
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
# Should raise an error because no source id function was specified
|
||||
await aindex(
|
||||
await loader.alazy_load(),
|
||||
arecord_manager,
|
||||
vector_store,
|
||||
cleanup="incremental",
|
||||
source_id_key="source",
|
||||
)
|
||||
|
||||
|
||||
def test_no_delete(
|
||||
record_manager: SQLRecordManager, vector_store: InMemoryVectorStore
|
||||
) -> None:
|
||||
@ -332,6 +565,89 @@ def test_no_delete(
|
||||
}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.requires("aiosqlite")
|
||||
async def test_ano_delete(
|
||||
arecord_manager: SQLRecordManager, vector_store: InMemoryVectorStore
|
||||
) -> None:
|
||||
"""Test indexing without a deletion strategy."""
|
||||
loader = ToyLoader(
|
||||
documents=[
|
||||
Document(
|
||||
page_content="This is a test document.",
|
||||
metadata={"source": "1"},
|
||||
),
|
||||
Document(
|
||||
page_content="This is another document.",
|
||||
metadata={"source": "2"},
|
||||
),
|
||||
]
|
||||
)
|
||||
|
||||
with patch.object(
|
||||
arecord_manager, "aget_time", return_value=datetime(2021, 1, 2).timestamp()
|
||||
):
|
||||
assert await aindex(
|
||||
await loader.alazy_load(),
|
||||
arecord_manager,
|
||||
vector_store,
|
||||
cleanup=None,
|
||||
source_id_key="source",
|
||||
) == {
|
||||
"num_added": 2,
|
||||
"num_deleted": 0,
|
||||
"num_skipped": 0,
|
||||
"num_updated": 0,
|
||||
}
|
||||
|
||||
# If we add the same content twice it should be skipped
|
||||
with patch.object(
|
||||
arecord_manager, "aget_time", return_value=datetime(2021, 1, 2).timestamp()
|
||||
):
|
||||
assert await aindex(
|
||||
await loader.alazy_load(),
|
||||
arecord_manager,
|
||||
vector_store,
|
||||
cleanup=None,
|
||||
source_id_key="source",
|
||||
) == {
|
||||
"num_added": 0,
|
||||
"num_deleted": 0,
|
||||
"num_skipped": 2,
|
||||
"num_updated": 0,
|
||||
}
|
||||
|
||||
loader = ToyLoader(
|
||||
documents=[
|
||||
Document(
|
||||
page_content="mutated content",
|
||||
metadata={"source": "1"},
|
||||
),
|
||||
Document(
|
||||
page_content="This is another document.",
|
||||
metadata={"source": "2"},
|
||||
),
|
||||
]
|
||||
)
|
||||
|
||||
# Should result in no updates or deletions!
|
||||
with patch.object(
|
||||
arecord_manager, "aget_time", return_value=datetime(2021, 1, 2).timestamp()
|
||||
):
|
||||
assert await aindex(
|
||||
await loader.alazy_load(),
|
||||
arecord_manager,
|
||||
vector_store,
|
||||
cleanup=None,
|
||||
source_id_key="source",
|
||||
) == {
|
||||
"num_added": 1,
|
||||
"num_deleted": 0,
|
||||
"num_skipped": 1,
|
||||
"num_updated": 0,
|
||||
}
|
||||
|
||||
|
||||
def test_incremental_delete(
|
||||
record_manager: SQLRecordManager, vector_store: InMemoryVectorStore
|
||||
) -> None:
|
||||
@ -436,6 +752,112 @@ def test_incremental_delete(
|
||||
}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.requires("aiosqlite")
|
||||
async def test_aincremental_delete(
|
||||
arecord_manager: SQLRecordManager, vector_store: InMemoryVectorStore
|
||||
) -> None:
|
||||
"""Test indexing with incremental deletion strategy."""
|
||||
loader = ToyLoader(
|
||||
documents=[
|
||||
Document(
|
||||
page_content="This is a test document.",
|
||||
metadata={"source": "1"},
|
||||
),
|
||||
Document(
|
||||
page_content="This is another document.",
|
||||
metadata={"source": "2"},
|
||||
),
|
||||
]
|
||||
)
|
||||
|
||||
with patch.object(
|
||||
arecord_manager, "aget_time", return_value=datetime(2021, 1, 2).timestamp()
|
||||
):
|
||||
assert await aindex(
|
||||
await loader.alazy_load(),
|
||||
arecord_manager,
|
||||
vector_store,
|
||||
cleanup="incremental",
|
||||
source_id_key="source",
|
||||
) == {
|
||||
"num_added": 2,
|
||||
"num_deleted": 0,
|
||||
"num_skipped": 0,
|
||||
"num_updated": 0,
|
||||
}
|
||||
|
||||
doc_texts = set(
|
||||
# Ignoring type since doc should be in the store and not a None
|
||||
vector_store.store.get(uid).page_content # type: ignore
|
||||
for uid in vector_store.store
|
||||
)
|
||||
assert doc_texts == {"This is another document.", "This is a test document."}
|
||||
|
||||
# Attempt to index again and verify that nothing changes
|
||||
with patch.object(
|
||||
arecord_manager, "aget_time", return_value=datetime(2021, 1, 2).timestamp()
|
||||
):
|
||||
assert await aindex(
|
||||
await loader.alazy_load(),
|
||||
arecord_manager,
|
||||
vector_store,
|
||||
cleanup="incremental",
|
||||
source_id_key="source",
|
||||
) == {
|
||||
"num_added": 0,
|
||||
"num_deleted": 0,
|
||||
"num_skipped": 2,
|
||||
"num_updated": 0,
|
||||
}
|
||||
|
||||
# Create 2 documents from the same source all with mutated content
|
||||
loader = ToyLoader(
|
||||
documents=[
|
||||
Document(
|
||||
page_content="mutated document 1",
|
||||
metadata={"source": "1"},
|
||||
),
|
||||
Document(
|
||||
page_content="mutated document 2",
|
||||
metadata={"source": "1"},
|
||||
),
|
||||
Document(
|
||||
page_content="This is another document.", # <-- Same as original
|
||||
metadata={"source": "2"},
|
||||
),
|
||||
]
|
||||
)
|
||||
|
||||
# Index the mutated documents and verify that the stale one is deleted
|
||||
with patch.object(
|
||||
arecord_manager, "aget_time", return_value=datetime(2021, 1, 3).timestamp()
|
||||
):
|
||||
assert await aindex(
|
||||
await loader.alazy_load(),
|
||||
arecord_manager,
|
||||
vector_store,
|
||||
cleanup="incremental",
|
||||
source_id_key="source",
|
||||
) == {
|
||||
"num_added": 2,
|
||||
"num_deleted": 1,
|
||||
"num_skipped": 1,
|
||||
"num_updated": 0,
|
||||
}
|
||||
|
||||
doc_texts = set(
|
||||
# Ignoring type since doc should be in the store and not a None
|
||||
vector_store.store.get(uid).page_content # type: ignore
|
||||
for uid in vector_store.store
|
||||
)
|
||||
assert doc_texts == {
|
||||
"mutated document 1",
|
||||
"mutated document 2",
|
||||
"This is another document.",
|
||||
}
|
||||
|
||||
|
||||
def test_indexing_with_no_docs(
|
||||
record_manager: SQLRecordManager, vector_store: VectorStore
|
||||
) -> None:
|
||||
@ -450,6 +872,24 @@ def test_indexing_with_no_docs(
|
||||
}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.requires("aiosqlite")
|
||||
async def test_aindexing_with_no_docs(
|
||||
arecord_manager: SQLRecordManager, vector_store: VectorStore
|
||||
) -> None:
|
||||
"""Check edge case when loader returns no new docs."""
|
||||
loader = ToyLoader(documents=[])
|
||||
|
||||
assert await aindex(
|
||||
await loader.alazy_load(), arecord_manager, vector_store, cleanup="full"
|
||||
) == {
|
||||
"num_added": 0,
|
||||
"num_deleted": 0,
|
||||
"num_skipped": 0,
|
||||
"num_updated": 0,
|
||||
}
|
||||
|
||||
|
||||
def test_deduplication(
|
||||
record_manager: SQLRecordManager, vector_store: VectorStore
|
||||
) -> None:
|
||||
@ -474,6 +914,32 @@ def test_deduplication(
|
||||
}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.requires("aiosqlite")
|
||||
async def test_adeduplication(
|
||||
arecord_manager: SQLRecordManager, vector_store: VectorStore
|
||||
) -> None:
|
||||
"""Check edge case when loader returns no new docs."""
|
||||
docs = [
|
||||
Document(
|
||||
page_content="This is a test document.",
|
||||
metadata={"source": "1"},
|
||||
),
|
||||
Document(
|
||||
page_content="This is a test document.",
|
||||
metadata={"source": "1"},
|
||||
),
|
||||
]
|
||||
|
||||
# Should result in only a single document being added
|
||||
assert await aindex(docs, arecord_manager, vector_store, cleanup="full") == {
|
||||
"num_added": 1,
|
||||
"num_deleted": 0,
|
||||
"num_skipped": 0,
|
||||
"num_updated": 0,
|
||||
}
|
||||
|
||||
|
||||
def test_cleanup_with_different_batchsize(
|
||||
record_manager: SQLRecordManager, vector_store: VectorStore
|
||||
) -> None:
|
||||
@ -511,6 +977,45 @@ def test_cleanup_with_different_batchsize(
|
||||
}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.requires("aiosqlite")
|
||||
async def test_async_cleanup_with_different_batchsize(
|
||||
arecord_manager: SQLRecordManager, vector_store: InMemoryVectorStore
|
||||
) -> None:
|
||||
"""Check that we can clean up with different batch size."""
|
||||
docs = [
|
||||
Document(
|
||||
page_content="This is a test document.",
|
||||
metadata={"source": str(d)},
|
||||
)
|
||||
for d in range(1000)
|
||||
]
|
||||
|
||||
assert await aindex(docs, arecord_manager, vector_store, cleanup="full") == {
|
||||
"num_added": 1000,
|
||||
"num_deleted": 0,
|
||||
"num_skipped": 0,
|
||||
"num_updated": 0,
|
||||
}
|
||||
|
||||
docs = [
|
||||
Document(
|
||||
page_content="Different doc",
|
||||
metadata={"source": str(d)},
|
||||
)
|
||||
for d in range(1001)
|
||||
]
|
||||
|
||||
assert await aindex(
|
||||
docs, arecord_manager, vector_store, cleanup="full", cleanup_batch_size=17
|
||||
) == {
|
||||
"num_added": 1001,
|
||||
"num_deleted": 1000,
|
||||
"num_skipped": 0,
|
||||
"num_updated": 0,
|
||||
}
|
||||
|
||||
|
||||
def test_deduplication_v2(
|
||||
record_manager: SQLRecordManager, vector_store: VectorStore
|
||||
) -> None:
|
||||
@ -547,3 +1052,29 @@ def test_deduplication_v2(
|
||||
[document.page_content for document in vector_store.store.values()]
|
||||
)
|
||||
assert contents == ["1", "2", "3"]
|
||||
|
||||
|
||||
async def _to_async_iter(it: Iterable[Any]) -> AsyncIterator[Any]:
|
||||
"""Convert an iterable to an async iterator."""
|
||||
for i in it:
|
||||
yield i
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_abatch() -> None:
|
||||
"""Test the abatch function."""
|
||||
batches = _abatch(5, _to_async_iter(range(12)))
|
||||
assert isinstance(batches, AsyncIterator)
|
||||
assert [batch async for batch in batches] == [
|
||||
[0, 1, 2, 3, 4],
|
||||
[5, 6, 7, 8, 9],
|
||||
[10, 11],
|
||||
]
|
||||
|
||||
batches = _abatch(1, _to_async_iter(range(3)))
|
||||
assert isinstance(batches, AsyncIterator)
|
||||
assert [batch async for batch in batches] == [[0], [1], [2]]
|
||||
|
||||
batches = _abatch(2, _to_async_iter(range(5)))
|
||||
assert isinstance(batches, AsyncIterator)
|
||||
assert [batch async for batch in batches] == [[0, 1], [2, 3], [4]]
|
||||
|
@ -2,6 +2,8 @@ from datetime import datetime
from unittest.mock import patch

import pytest
import pytest_asyncio
from sqlalchemy import select

from langchain.indexes._sql_record_manager import SQLRecordManager, UpsertionRecord

@ -15,6 +17,20 @@ def manager() -> SQLRecordManager:
    return record_manager


@pytest_asyncio.fixture  # type: ignore
@pytest.mark.requires("aiosqlite")
async def amanager() -> SQLRecordManager:
"""Initialize the test database and yield the TimestampedSet instance."""
|
||||
# Initialize and yield the TimestampedSet instance
|
||||
    record_manager = SQLRecordManager(
        "kittens",
        db_url="sqlite+aiosqlite:///:memory:",
        async_mode=True,
    )
    await record_manager.acreate_schema()
    return record_manager


def test_update(manager: SQLRecordManager) -> None:
    """Test updating records in the database."""
    # no keys should be present in the set
@ -28,6 +44,21 @@ def test_update(manager: SQLRecordManager) -> None:
    assert read_keys == ["key1", "key2", "key3"]


@pytest.mark.asyncio
@pytest.mark.requires("aiosqlite")
async def test_aupdate(amanager: SQLRecordManager) -> None:
    """Test updating records in the database."""
    # no keys should be present in the set
    read_keys = await amanager.alist_keys()
    assert read_keys == []
    # Insert records
    keys = ["key1", "key2", "key3"]
    await amanager.aupdate(keys)
    # Retrieve the records
    read_keys = await amanager.alist_keys()
    assert read_keys == ["key1", "key2", "key3"]


def test_update_timestamp(manager: SQLRecordManager) -> None:
    """Test updating records in the database."""
    # no keys should be present in the set
@ -119,6 +150,117 @@ def test_update_timestamp(manager: SQLRecordManager) -> None:
    ]


@pytest.mark.asyncio
@pytest.mark.requires("aiosqlite")
async def test_aupdate_timestamp(amanager: SQLRecordManager) -> None:
    """Test updating records in the database."""
    # Freeze aget_time so the upsert records a deterministic timestamp
    with patch.object(
        amanager, "aget_time", return_value=datetime(2021, 1, 2).timestamp()
    ):
        await amanager.aupdate(["key1"])

    async with amanager._amake_session() as session:
        records = (
            (
                await session.execute(
                    select(UpsertionRecord).filter(
                        UpsertionRecord.namespace == amanager.namespace
                    )
                )
            )
            .scalars()
            .all()
        )

        assert [
            {
                "key": record.key,
                "namespace": record.namespace,
                "updated_at": record.updated_at,
                "group_id": record.group_id,
            }
            for record in records
        ] == [
            {
                "group_id": None,
                "key": "key1",
                "namespace": "kittens",
                "updated_at": datetime(2021, 1, 2, 0, 0).timestamp(),
            }
        ]

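    # Re-upserting the same key at a later frozen time should overwrite updated_at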
    with patch.object(
        amanager, "aget_time", return_value=datetime(2023, 1, 2).timestamp()
    ):
        await amanager.aupdate(["key1"])

    async with amanager._amake_session() as session:
        records = (
            (
                await session.execute(
                    select(UpsertionRecord).filter(
                        UpsertionRecord.namespace == amanager.namespace
                    )
                )
            )
            .scalars()
            .all()
        )

        assert [
            {
                "key": record.key,
                "namespace": record.namespace,
                "updated_at": record.updated_at,
                "group_id": record.group_id,
            }
            for record in records
        ] == [
            {
                "group_id": None,
                "key": "key1",
                "namespace": "kittens",
                "updated_at": datetime(2023, 1, 2, 0, 0).timestamp(),
            }
        ]

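    # Upserting with a group id should attach the group to the existing record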
    with patch.object(
        amanager, "aget_time", return_value=datetime(2023, 2, 2).timestamp()
    ):
        await amanager.aupdate(["key1"], group_ids=["group1"])

    async with amanager._amake_session() as session:
        records = (
            (
                await session.execute(
                    select(UpsertionRecord).filter(
                        UpsertionRecord.namespace == amanager.namespace
                    )
                )
            )
            .scalars()
            .all()
        )

        assert [
            {
                "key": record.key,
                "namespace": record.namespace,
                "updated_at": record.updated_at,
                "group_id": record.group_id,
            }
            for record in records
        ] == [
            {
                "group_id": "group1",
                "key": "key1",
                "namespace": "kittens",
                "updated_at": datetime(2023, 2, 2, 0, 0).timestamp(),
            }
        ]


def test_update_with_group_ids(manager: SQLRecordManager) -> None:
    """Test updating records in the database."""
    # no keys should be present in the set
@ -132,6 +274,21 @@ def test_update_with_group_ids(manager: SQLRecordManager) -> None:
    assert read_keys == ["key1", "key2", "key3"]


@pytest.mark.asyncio
@pytest.mark.requires("aiosqlite")
async def test_aupdate_with_group_ids(amanager: SQLRecordManager) -> None:
    """Test updating records in the database."""
    # no keys should be present in the set
    read_keys = await amanager.alist_keys()
    assert read_keys == []
    # Insert records
    keys = ["key1", "key2", "key3"]
    await amanager.aupdate(keys)
    # Retrieve the records
    read_keys = await amanager.alist_keys()
    assert read_keys == ["key1", "key2", "key3"]


def test_exists(manager: SQLRecordManager) -> None:
    """Test checking if keys exist in the database."""
    # Insert records
@ -147,6 +304,23 @@ def test_exists(manager: SQLRecordManager) -> None:
    assert exists == [True, False]


@pytest.mark.asyncio
@pytest.mark.requires("aiosqlite")
async def test_aexists(amanager: SQLRecordManager) -> None:
    """Test checking if keys exist in the database."""
    # Insert records
    keys = ["key1", "key2", "key3"]
    await amanager.aupdate(keys)
    # Check if the keys exist in the database
    exists = await amanager.aexists(keys)
    assert len(exists) == len(keys)
    assert exists == [True, True, True]

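    # "key4" was never inserted, so it should be reported as missing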
    exists = await amanager.aexists(["key1", "key4"])
    assert len(exists) == 2
    assert exists == [True, False]


def test_list_keys(manager: SQLRecordManager) -> None:
    """Test listing keys based on the provided date range."""
    # Insert records
@ -234,6 +408,98 @@ def test_list_keys(manager: SQLRecordManager) -> None:
    ) == ["key4"]


@pytest.mark.asyncio
@pytest.mark.requires("aiosqlite")
async def test_alist_keys(amanager: SQLRecordManager) -> None:
    """Test listing keys based on the provided date range."""
    # Insert records
    assert await amanager.alist_keys() == []
    async with amanager._amake_session() as session:
        # Add some keys with explicit updated_ats
        session.add(
            UpsertionRecord(
                key="key1",
                updated_at=datetime(2021, 1, 1).timestamp(),
                namespace="kittens",
            )
        )
        session.add(
            UpsertionRecord(
                key="key2",
                updated_at=datetime(2022, 1, 1).timestamp(),
                namespace="kittens",
            )
        )
        session.add(
            UpsertionRecord(
                key="key3",
                updated_at=datetime(2023, 1, 1).timestamp(),
                namespace="kittens",
            )
        )
        session.add(
            UpsertionRecord(
                key="key4",
                group_id="group1",
                updated_at=datetime(2024, 1, 1).timestamp(),
                namespace="kittens",
            )
        )
        # Insert keys from a different namespace, these should not be visible!
        session.add(
            UpsertionRecord(
                key="key1",
                updated_at=datetime(2021, 1, 1).timestamp(),
                namespace="puppies",
            )
        )
        session.add(
            UpsertionRecord(
                key="key5",
                updated_at=datetime(2021, 1, 1).timestamp(),
                namespace="puppies",
            )
        )
        await session.commit()

    # Retrieve all keys
    assert await amanager.alist_keys() == ["key1", "key2", "key3", "key4"]

    # Retrieve keys updated after a certain date
    assert await amanager.alist_keys(after=datetime(2022, 2, 1).timestamp()) == [
        "key3",
        "key4",
    ]

    # Retrieve keys updated before a certain date
    assert await amanager.alist_keys(before=datetime(2022, 2, 1).timestamp()) == [
        "key1",
        "key2",
    ]

    # No keys should have been updated before this early date
    assert await amanager.alist_keys(before=datetime(2019, 2, 1).timestamp()) == []

    # Retrieve keys in a time range
    assert await amanager.alist_keys(
        before=datetime(2022, 2, 1).timestamp(),
        after=datetime(2021, 11, 1).timestamp(),
    ) == ["key2"]

    assert await amanager.alist_keys(group_ids=["group1", "group2"]) == ["key4"]

    # Test multiple filters
    assert (
        await amanager.alist_keys(
            group_ids=["group1", "group2"], before=datetime(2019, 1, 1).timestamp()
        )
        == []
    )
    assert await amanager.alist_keys(
        group_ids=["group1", "group2"], after=datetime(2019, 1, 1).timestamp()
    ) == ["key4"]


def test_namespace_is_used(manager: SQLRecordManager) -> None:
    """Verify that namespace is taken into account for all operations."""
    assert manager.namespace == "kittens"
@ -261,6 +527,35 @@ def test_namespace_is_used(manager: SQLRecordManager) -> None:
    ]


@pytest.mark.asyncio
@pytest.mark.requires("aiosqlite")
async def test_anamespace_is_used(amanager: SQLRecordManager) -> None:
    """Verify that namespace is taken into account for all operations."""
    assert amanager.namespace == "kittens"
    async with amanager._amake_session() as session:
        # Add the same keys to two different namespaces
session.add(UpsertionRecord(key="key1", namespace="kittens"))
|
||||
session.add(UpsertionRecord(key="key2", namespace="kittens"))
|
||||
session.add(UpsertionRecord(key="key1", namespace="puppies"))
|
||||
session.add(UpsertionRecord(key="key3", namespace="puppies"))
|
||||
await session.commit()
|
||||
|
||||
assert await amanager.alist_keys() == ["key1", "key2"]
|
||||
await amanager.adelete_keys(["key1"])
|
||||
assert await amanager.alist_keys() == ["key2"]
|
||||
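    # Upsert "key3" into the "kittens" namespace with a group id; the "puppies" rows are unaffected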
await amanager.aupdate(["key3"], group_ids=["group3"])
|
||||
|
||||
async with amanager._amake_session() as session:
|
||||
results = (await session.execute(select(UpsertionRecord))).scalars().all()
|
||||
|
||||
assert sorted([(r.namespace, r.key, r.group_id) for r in results]) == [
|
||||
("kittens", "key2", None),
|
||||
("kittens", "key3", "group3"),
|
||||
("puppies", "key1", None),
|
||||
("puppies", "key3", None),
|
||||
]
|
||||
|
||||
|
||||
def test_delete_keys(manager: SQLRecordManager) -> None:
|
||||
"""Test deleting keys from the database."""
|
||||
# Insert records
|
||||
@ -274,3 +569,20 @@ def test_delete_keys(manager: SQLRecordManager) -> None:
|
||||
# Check if the deleted keys are no longer in the database
|
||||
remaining_keys = manager.list_keys()
|
||||
assert remaining_keys == ["key3"]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.requires("aiosqlite")
|
||||
async def test_adelete_keys(amanager: SQLRecordManager) -> None:
|
||||
"""Test deleting keys from the database."""
|
||||
# Insert records
|
||||
keys = ["key1", "key2", "key3"]
|
||||
await amanager.aupdate(keys)
|
||||
|
||||
# Delete some keys
|
||||
keys_to_delete = ["key1", "key2"]
|
||||
await amanager.adelete_keys(keys_to_delete)
|
||||
|
||||
# Check if the deleted keys are no longer in the database
|
||||
remaining_keys = await amanager.alist_keys()
|
||||
assert remaining_keys == ["key3"]
|
||||
|