mirror of
https://github.com/hwchase17/langchain.git
synced 2025-09-15 22:44:36 +00:00
core[minor],langchain[patch]: Move base indexing interface and logic to core (#20667)
This PR moves the indexing interface and logic to core. It makes the following namespace changes: `indexes` -> `indexing` and `indexes._api` -> `indexing.api`. Testing code is intentionally duplicated for now, since it tests different implementations of the record manager (in-memory vs. SQL). Common logic will need to be pulled out into the test client. A follow-up PR will move the SQL-based implementation outside of LangChain.
This commit is contained in:
0
libs/core/tests/unit_tests/indexing/__init__.py
Normal file
0
libs/core/tests/unit_tests/indexing/__init__.py
Normal file
105
libs/core/tests/unit_tests/indexing/in_memory.py
Normal file
105
libs/core/tests/unit_tests/indexing/in_memory.py
Normal file
@@ -0,0 +1,105 @@
|
||||
import time
|
||||
from typing import Dict, List, Optional, Sequence, TypedDict
|
||||
|
||||
from langchain_core.indexing.base import RecordManager
|
||||
|
||||
|
||||
class _Record(TypedDict):
    """Shape of the value stored for a single key by the record manager."""

    # Group the key belongs to, or None when no group was supplied.
    group_id: Optional[str]
    # Timestamp (float) recorded when the key was last written.
    updated_at: float
|
||||
|
||||
|
||||
class InMemoryRecordManager(RecordManager):
    """An in-memory record manager for testing purposes.

    Records are kept in a plain dictionary on the instance; every async
    method simply delegates to its synchronous counterpart.
    """

    def __init__(self, namespace: str) -> None:
        """Create an empty record manager for ``namespace``."""
        super().__init__(namespace)
        # Each key points to a dictionary
        # of {'group_id': group_id, 'updated_at': timestamp}
        self.records: Dict[str, _Record] = {}
        self.namespace = namespace

    def create_schema(self) -> None:
        """In-memory schema creation is simply ensuring the structure is initialized."""

    async def acreate_schema(self) -> None:
        """In-memory schema creation is simply ensuring the structure is initialized."""

    def get_time(self) -> float:
        """Get the current server time as a high resolution timestamp!"""
        return time.time()

    async def aget_time(self) -> float:
        """Get the current server time as a high resolution timestamp!"""
        return self.get_time()

    def update(
        self,
        keys: Sequence[str],
        *,
        group_ids: Optional[Sequence[Optional[str]]] = None,
        time_at_least: Optional[float] = None,
    ) -> None:
        """Insert ``keys`` or refresh the timestamp of keys already stored.

        Args:
            keys: Keys to write.
            group_ids: Optional group id for each key. When provided and
                non-empty it must have the same length as ``keys``.
            time_at_least: If given, the manager's current time must not be
                earlier than this value.

        Raises:
            ValueError: If ``keys`` and ``group_ids`` differ in length, or if
                ``time_at_least`` is later than the manager's current time.
        """
        if group_ids and len(keys) != len(group_ids):
            raise ValueError("Length of keys must match length of group_ids")

        # Read the clock once: validation then happens before any record is
        # touched, and every key written by this call shares one timestamp.
        now = self.get_time()
        if time_at_least is not None and time_at_least > now:
            raise ValueError("time_at_least must be in the past")

        for index, key in enumerate(keys):
            group_id = group_ids[index] if group_ids else None
            self.records[key] = {"group_id": group_id, "updated_at": now}

    async def aupdate(
        self,
        keys: Sequence[str],
        *,
        group_ids: Optional[Sequence[Optional[str]]] = None,
        time_at_least: Optional[float] = None,
    ) -> None:
        """Async variant of ``update``; delegates to the sync implementation."""
        self.update(keys, group_ids=group_ids, time_at_least=time_at_least)

    def exists(self, keys: Sequence[str]) -> List[bool]:
        """Return, for each key in order, whether it is currently stored."""
        return [key in self.records for key in keys]

    async def aexists(self, keys: Sequence[str]) -> List[bool]:
        """Async variant of ``exists``; delegates to the sync implementation."""
        return self.exists(keys)

    def list_keys(
        self,
        *,
        before: Optional[float] = None,
        after: Optional[float] = None,
        group_ids: Optional[Sequence[str]] = None,
        limit: Optional[int] = None,
    ) -> List[str]:
        """List stored keys, optionally filtered by time, group and count.

        Args:
            before: Keep only keys whose timestamp is strictly less than this.
            after: Keep only keys whose timestamp is strictly greater than this.
            group_ids: Keep only keys whose group id is in this sequence; an
                empty or missing sequence disables the filter.
            limit: Maximum number of keys to return.

        Returns:
            The matching keys, in the iteration order of the record dictionary.
        """
        matches: List[str] = []
        for key, data in self.records.items():
            # Compare against None explicitly so that a bound of 0.0 is
            # honoured instead of being mistaken for "no bound".
            if before is not None and data["updated_at"] >= before:
                continue
            if after is not None and data["updated_at"] <= after:
                continue
            if group_ids and data["group_id"] not in group_ids:
                continue
            matches.append(key)

        # Same reasoning as above: limit=0 means "return nothing".
        if limit is not None:
            return matches[:limit]
        return matches

    async def alist_keys(
        self,
        *,
        before: Optional[float] = None,
        after: Optional[float] = None,
        group_ids: Optional[Sequence[str]] = None,
        limit: Optional[int] = None,
    ) -> List[str]:
        """Async variant of ``list_keys``; delegates to the sync implementation."""
        return self.list_keys(
            before=before, after=after, group_ids=group_ids, limit=limit
        )

    def delete_keys(self, keys: Sequence[str]) -> None:
        """Remove ``keys`` from the store, ignoring keys that are not present."""
        for key in keys:
            self.records.pop(key, None)

    async def adelete_keys(self, keys: Sequence[str]) -> None:
        """Async variant of ``delete_keys``; delegates to the sync implementation."""
        self.delete_keys(keys)
|
50
libs/core/tests/unit_tests/indexing/test_hashed_document.py
Normal file
50
libs/core/tests/unit_tests/indexing/test_hashed_document.py
Normal file
@@ -0,0 +1,50 @@
|
||||
import pytest
|
||||
|
||||
from langchain_core.documents import Document
|
||||
from langchain_core.indexing.api import _HashedDocument
|
||||
|
||||
|
||||
def test_hashed_document_hashing() -> None:
    """A hashed document built with an explicit uid exposes a string hash."""
    document = _HashedDocument(  # type: ignore[call-arg]
        uid="123",
        page_content="Lorem ipsum dolor sit amet",
        metadata={"key": "value"},
    )
    assert isinstance(document.hash_, str)
|
||||
|
||||
|
||||
def test_hashing_with_missing_content() -> None:
    """Check that TypeError is raised if page_content is missing."""
    # Only metadata is supplied; constructing the document must fail.
    with pytest.raises(TypeError):
        _HashedDocument(
            metadata={"key": "value"},
        )  # type: ignore
|
||||
|
||||
|
||||
def test_uid_auto_assigned_to_hash() -> None:
    """When no uid is given, the uid equals the document hash."""
    document = _HashedDocument(  # type: ignore[call-arg]
        page_content="Lorem ipsum dolor sit amet",
        metadata={"key": "value"},
    )
    assert document.uid == document.hash_
|
||||
|
||||
|
||||
def test_to_document() -> None:
    """Converting a hashed document yields a Document with the same data."""
    source = _HashedDocument(  # type: ignore[call-arg]
        page_content="Lorem ipsum dolor sit amet",
        metadata={"key": "value"},
    )

    converted = source.to_document()

    assert isinstance(converted, Document)
    assert converted.page_content == "Lorem ipsum dolor sit amet"
    assert converted.metadata == {"key": "value"}
|
||||
|
||||
|
||||
def test_from_document() -> None:
    """Building from a Document gives a deterministic hash that doubles as uid."""
    source = Document(
        page_content="Lorem ipsum dolor sit amet",
        metadata={"key": "value"},
    )

    hashed = _HashedDocument.from_document(source)

    # The hash is pinned to a fixed value, so it must not vary between runs.
    assert hashed.hash_ == "fd1dc827-051b-537d-a1fe-1fa043e8b276"
    assert hashed.uid == hashed.hash_
|
@@ -0,0 +1,223 @@
|
||||
from datetime import datetime
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
import pytest_asyncio
|
||||
|
||||
from tests.unit_tests.indexing.in_memory import InMemoryRecordManager
|
||||
|
||||
|
||||
@pytest.fixture()
def manager() -> InMemoryRecordManager:
    """Return a fresh in-memory record manager with its schema created."""
    record_manager = InMemoryRecordManager(namespace="kittens")
    record_manager.create_schema()
    return record_manager
|
||||
|
||||
|
||||
@pytest_asyncio.fixture()
async def amanager() -> InMemoryRecordManager:
    """Return a fresh in-memory record manager, set up via the async schema call."""
    record_manager = InMemoryRecordManager(namespace="kittens")
    await record_manager.acreate_schema()
    return record_manager
|
||||
|
||||
|
||||
def test_update(manager: InMemoryRecordManager) -> None:
    """Keys passed to ``update`` become visible through ``list_keys``."""
    # A fresh manager starts out empty.
    assert manager.list_keys() == []

    inserted = ["key1", "key2", "key3"]
    manager.update(inserted)

    # The listing must now match the inserted keys, in order.
    assert manager.list_keys() == ["key1", "key2", "key3"]
|
||||
|
||||
|
||||
async def test_aupdate(amanager: InMemoryRecordManager) -> None:
    """Keys passed to ``aupdate`` become visible through ``alist_keys``."""
    # A fresh manager starts out empty.
    assert await amanager.alist_keys() == []

    inserted = ["key1", "key2", "key3"]
    await amanager.aupdate(inserted)

    # The listing must now match the inserted keys, in order.
    assert await amanager.alist_keys() == ["key1", "key2", "key3"]
|
||||
|
||||
|
||||
def test_update_timestamp(manager: InMemoryRecordManager) -> None:
    """Re-updating a key moves the timestamp used by the before/after filters."""

    def _ts(year: int, month: int, day: int) -> float:
        return datetime(year, month, day).timestamp()

    # First write, with the manager's clock pinned to 2021-01-02.
    with patch.object(manager, "get_time", return_value=_ts(2021, 1, 2)):
        manager.update(["key1"])

    assert manager.list_keys() == ["key1"]
    assert manager.list_keys(before=_ts(2021, 1, 1)) == []
    assert manager.list_keys(after=_ts(2021, 1, 1)) == ["key1"]
    assert manager.list_keys(after=_ts(2021, 1, 3)) == []

    # Second write of the same key, with the clock pinned to 2023-01-05.
    with patch.object(manager, "get_time", return_value=_ts(2023, 1, 5)):
        manager.update(["key1"])

    assert manager.list_keys() == ["key1"]
    assert manager.list_keys(before=_ts(2023, 1, 1)) == []
    assert manager.list_keys(after=_ts(2023, 1, 1)) == ["key1"]
    assert manager.list_keys(after=_ts(2023, 1, 3)) == ["key1"]
|
||||
|
||||
|
||||
async def test_aupdate_timestamp(manager: InMemoryRecordManager) -> None:
    """Re-updating a key via ``aupdate`` moves the timestamp used by filters."""

    def _ts(year: int, month: int, day: int) -> float:
        return datetime(year, month, day).timestamp()

    # First write, with the manager's clock pinned to 2021-01-02.
    with patch.object(manager, "get_time", return_value=_ts(2021, 1, 2)):
        await manager.aupdate(["key1"])

    assert await manager.alist_keys() == ["key1"]
    assert await manager.alist_keys(before=_ts(2021, 1, 1)) == []
    assert await manager.alist_keys(after=_ts(2021, 1, 1)) == ["key1"]
    assert await manager.alist_keys(after=_ts(2021, 1, 3)) == []

    # Second write of the same key, with the clock pinned to 2023-01-05.
    with patch.object(manager, "get_time", return_value=_ts(2023, 1, 5)):
        await manager.aupdate(["key1"])

    assert await manager.alist_keys() == ["key1"]
    assert await manager.alist_keys(before=_ts(2023, 1, 1)) == []
    assert await manager.alist_keys(after=_ts(2023, 1, 1)) == ["key1"]
    assert await manager.alist_keys(after=_ts(2023, 1, 3)) == ["key1"]
|
||||
|
||||
|
||||
def test_exists(manager: InMemoryRecordManager) -> None:
    """``exists`` reports one boolean per queried key."""
    stored = ["key1", "key2", "key3"]
    manager.update(stored)

    # Every stored key is reported as present.
    found = manager.exists(stored)
    assert len(found) == len(stored)
    assert found == [True, True, True]

    # A key that was never stored is reported as absent.
    found = manager.exists(["key1", "key4"])
    assert len(found) == 2
    assert found == [True, False]
|
||||
|
||||
|
||||
async def test_aexists(amanager: InMemoryRecordManager) -> None:
    """``aexists`` reports one boolean per queried key."""
    stored = ["key1", "key2", "key3"]
    await amanager.aupdate(stored)

    # Every stored key is reported as present.
    found = await amanager.aexists(stored)
    assert len(found) == len(stored)
    assert found == [True, True, True]

    # A key that was never stored is reported as absent.
    found = await amanager.aexists(["key1", "key4"])
    assert len(found) == 2
    assert found == [True, False]
|
||||
|
||||
|
||||
async def test_list_keys(manager: InMemoryRecordManager) -> None:
    """Check group, time-window and limit filters of (a)list_keys."""
    all_keys = ["key1", "key2", "key3", "key4", "key5"]
    early_keys = ["key1", "key2", "key3", "key4"]
    jan_2 = datetime(2021, 1, 2).timestamp()
    jan_3 = datetime(2021, 1, 3).timestamp()
    jan_10 = datetime(2021, 1, 10).timestamp()

    # Nothing has been stored yet.
    assert manager.list_keys() == []
    assert await manager.alist_keys() == []

    # Four keys written on Jan 2, two of them tagged with a group.
    with patch.object(manager, "get_time", return_value=jan_2):
        manager.update(["key1", "key2"])
        manager.update(["key3"], group_ids=["group1"])
        manager.update(["key4"], group_ids=["group2"])

    # One more key written on Jan 10.
    with patch.object(manager, "get_time", return_value=jan_10):
        manager.update(["key5"])

    assert sorted(manager.list_keys()) == all_keys
    assert sorted(await manager.alist_keys()) == all_keys

    # By group
    assert manager.list_keys(group_ids=["group1"]) == ["key3"]
    assert await manager.alist_keys(group_ids=["group1"]) == ["key3"]

    # Before
    assert sorted(manager.list_keys(before=jan_3)) == early_keys
    assert sorted(await manager.alist_keys(before=jan_3)) == early_keys

    # After
    assert sorted(manager.list_keys(after=jan_3)) == ["key5"]
    assert sorted(await manager.alist_keys(after=jan_3)) == ["key5"]

    # Limit: exactly one of the stored keys comes back.
    sync_page = manager.list_keys(limit=1)
    assert len(sync_page) == 1
    assert sync_page[0] in all_keys

    async_page = await manager.alist_keys(limit=1)
    assert len(async_page) == 1
    assert async_page[0] in all_keys
|
||||
|
||||
|
||||
def test_delete_keys(manager: InMemoryRecordManager) -> None:
    """Deleted keys disappear from the listing; the rest stay."""
    manager.update(["key1", "key2", "key3"])

    # Remove two of the three stored keys.
    manager.delete_keys(["key1", "key2"])

    # Only the untouched key may remain.
    assert manager.list_keys() == ["key3"]
|
||||
|
||||
|
||||
async def test_adelete_keys(amanager: InMemoryRecordManager) -> None:
    """Keys deleted via ``adelete_keys`` disappear from the listing."""
    await amanager.aupdate(["key1", "key2", "key3"])

    # Remove two of the three stored keys.
    await amanager.adelete_keys(["key1", "key2"])

    # Only the untouched key may remain.
    assert await amanager.alist_keys() == ["key3"]
|
1398
libs/core/tests/unit_tests/indexing/test_indexing.py
Normal file
1398
libs/core/tests/unit_tests/indexing/test_indexing.py
Normal file
File diff suppressed because it is too large
Load Diff
12
libs/core/tests/unit_tests/indexing/test_public_api.py
Normal file
12
libs/core/tests/unit_tests/indexing/test_public_api.py
Normal file
@@ -0,0 +1,12 @@
|
||||
from langchain_core.indexing import __all__
|
||||
|
||||
|
||||
def test_all() -> None:
    """Use to catch obvious breaking changes."""
    expected = ["aindex", "index", "IndexingResult", "RecordManager"]

    # The export list must be kept in case-insensitive alphabetical order.
    assert __all__ == sorted(__all__, key=str.lower)
    assert __all__ == expected
|
Reference in New Issue
Block a user