mirror of
https://github.com/hwchase17/langchain.git
synced 2025-09-01 02:50:47 +00:00
core[minor]: Adds an in-memory implementation of RecordManager (#13200)
**Description:** LangChain offers three technologies to save data: - [vectorstore](https://python.langchain.com/docs/modules/data_connection/vectorstores/) - [docstore](https://js.langchain.com/docs/api/schema/classes/Docstore) - [record manager](https://python.langchain.com/docs/modules/data_connection/indexing) If you want to combine these technologies in a sample persistence strategy, you need a common in-memory implementation of each. `DocStore` already provides `InMemoryDocstore`. We propose the class `InMemoryRecordManager` to complete the system. This is the prelude to another pull request, which needs a consistent combination of persistence components. **Tag maintainer:** @baskaryan **Twitter handle:** @pprados --------- Co-authored-by: Eugene Yurtsev <eyurtsev@gmail.com>
This commit is contained in:
@@ -5,11 +5,12 @@ a vectorstore while avoiding duplicated content and over-writing content
|
|||||||
if it's unchanged.
|
if it's unchanged.
|
||||||
"""
|
"""
|
||||||
from langchain_core.indexing.api import IndexingResult, aindex, index
|
from langchain_core.indexing.api import IndexingResult, aindex, index
|
||||||
from langchain_core.indexing.base import RecordManager
|
from langchain_core.indexing.base import InMemoryRecordManager, RecordManager
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"aindex",
|
"aindex",
|
||||||
"index",
|
"index",
|
||||||
"IndexingResult",
|
"IndexingResult",
|
||||||
|
"InMemoryRecordManager",
|
||||||
"RecordManager",
|
"RecordManager",
|
||||||
]
|
]
|
||||||
|
@@ -1,7 +1,8 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import time
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from typing import List, Optional, Sequence
|
from typing import Dict, List, Optional, Sequence, TypedDict
|
||||||
|
|
||||||
|
|
||||||
class RecordManager(ABC):
|
class RecordManager(ABC):
|
||||||
@@ -215,3 +216,104 @@ class RecordManager(ABC):
|
|||||||
Args:
|
Args:
|
||||||
keys: A list of keys to delete.
|
keys: A list of keys to delete.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
class _Record(TypedDict):
|
||||||
|
group_id: Optional[str]
|
||||||
|
updated_at: float
|
||||||
|
|
||||||
|
|
||||||
|
class InMemoryRecordManager(RecordManager):
    """An in-memory record manager for testing purposes.

    Records are kept in a plain dictionary on the instance, so nothing is
    persisted. The async methods simply call their synchronous
    counterparts.
    """

    def __init__(self, namespace: str) -> None:
        """Create an empty record manager.

        Args:
            namespace: Namespace forwarded to the base class and stored on
                the instance.
        """
        super().__init__(namespace)
        # Each key points to a dictionary
        # of {'group_id': group_id, 'updated_at': timestamp}
        self.records: Dict[str, _Record] = {}
        self.namespace = namespace

    def create_schema(self) -> None:
        """Do nothing: the backing dictionary is created in ``__init__``."""

    async def acreate_schema(self) -> None:
        """Do nothing: the backing dictionary is created in ``__init__``."""

    def get_time(self) -> float:
        """Return the current time as given by ``time.time()``."""
        return time.time()

    async def aget_time(self) -> float:
        """Async variant of ``get_time``; returns the same value."""
        return self.get_time()

    def update(
        self,
        keys: Sequence[str],
        *,
        group_ids: Optional[Sequence[Optional[str]]] = None,
        time_at_least: Optional[float] = None,
    ) -> None:
        """Insert or overwrite the records for ``keys``.

        Every key is stamped with the value of ``get_time()`` at the moment
        it is written.

        Args:
            keys: Keys to upsert.
            group_ids: Optional group id for each key, aligned with ``keys``.
                ``None`` or an empty sequence stores ``None`` for every key.
            time_at_least: Optional lower bound that must not be later than
                the manager's current time.

        Raises:
            ValueError: If a non-empty ``group_ids`` differs in length from
                ``keys``, or if ``time_at_least`` is later than the current
                time.
        """
        if group_ids and len(keys) != len(group_ids):
            raise ValueError("Length of keys must match length of group_ids")
        # The bound does not depend on the key, so validate it once up front;
        # a bad argument is then rejected before any record is written.
        if time_at_least is not None and time_at_least > self.get_time():
            raise ValueError("time_at_least must be in the past")

        groups: Sequence[Optional[str]] = (
            group_ids if group_ids else [None] * len(keys)
        )
        for key, group_id in zip(keys, groups):
            self.records[key] = {"group_id": group_id, "updated_at": self.get_time()}

    async def aupdate(
        self,
        keys: Sequence[str],
        *,
        group_ids: Optional[Sequence[Optional[str]]] = None,
        time_at_least: Optional[float] = None,
    ) -> None:
        """Async variant of ``update``; same arguments and errors."""
        self.update(keys, group_ids=group_ids, time_at_least=time_at_least)

    def exists(self, keys: Sequence[str]) -> List[bool]:
        """Report, for each key in order, whether a record is stored for it.

        Args:
            keys: Keys to look up.

        Returns:
            One boolean per entry of ``keys``.
        """
        return [key in self.records for key in keys]

    async def aexists(self, keys: Sequence[str]) -> List[bool]:
        """Async variant of ``exists``."""
        return self.exists(keys)

    def list_keys(
        self,
        *,
        before: Optional[float] = None,
        after: Optional[float] = None,
        group_ids: Optional[Sequence[str]] = None,
        limit: Optional[int] = None,
    ) -> List[str]:
        """List stored keys that satisfy all of the given filters.

        Args:
            before: Keep only records whose ``updated_at`` is strictly
                earlier than this timestamp.
            after: Keep only records whose ``updated_at`` is strictly later
                than this timestamp.
            group_ids: Keep only records whose group id is in this sequence.
                ``None`` or an empty sequence disables the filter.
            limit: Maximum number of keys to return. ``None`` means no
                limit; a value of zero or less yields an empty list.

        Returns:
            Matching keys, in the iteration order of the backing dictionary.
        """
        matches: List[str] = []
        for key, record in self.records.items():
            # Stop scanning as soon as enough keys were collected. Comparing
            # against None (not truthiness) makes limit=0 mean "no keys".
            if limit is not None and len(matches) >= limit:
                break
            stamp = record["updated_at"]
            # `is not None` so that a bound of 0.0 is honoured, not skipped.
            if before is not None and stamp >= before:
                continue
            if after is not None and stamp <= after:
                continue
            if group_ids and record["group_id"] not in group_ids:
                continue
            matches.append(key)
        return matches

    async def alist_keys(
        self,
        *,
        before: Optional[float] = None,
        after: Optional[float] = None,
        group_ids: Optional[Sequence[str]] = None,
        limit: Optional[int] = None,
    ) -> List[str]:
        """Async variant of ``list_keys``; same filters and result."""
        return self.list_keys(
            before=before, after=after, group_ids=group_ids, limit=limit
        )

    def delete_keys(self, keys: Sequence[str]) -> None:
        """Remove the records for ``keys``; unknown keys are ignored.

        Args:
            keys: Keys to delete.
        """
        for key in keys:
            # pop with a default: one lookup, no error for missing keys.
            self.records.pop(key, None)

    async def adelete_keys(self, keys: Sequence[str]) -> None:
        """Async variant of ``delete_keys``."""
        self.delete_keys(keys)
|
||||||
|
@@ -1,105 +0,0 @@
|
|||||||
import time
|
|
||||||
from typing import Dict, List, Optional, Sequence, TypedDict
|
|
||||||
|
|
||||||
from langchain_core.indexing.base import RecordManager
|
|
||||||
|
|
||||||
|
|
||||||
class _Record(TypedDict):
|
|
||||||
group_id: Optional[str]
|
|
||||||
updated_at: float
|
|
||||||
|
|
||||||
|
|
||||||
class InMemoryRecordManager(RecordManager):
|
|
||||||
"""An in-memory record manager for testing purposes."""
|
|
||||||
|
|
||||||
def __init__(self, namespace: str) -> None:
|
|
||||||
super().__init__(namespace)
|
|
||||||
# Each key points to a dictionary
|
|
||||||
# of {'group_id': group_id, 'updated_at': timestamp}
|
|
||||||
self.records: Dict[str, _Record] = {}
|
|
||||||
self.namespace = namespace
|
|
||||||
|
|
||||||
def create_schema(self) -> None:
|
|
||||||
"""In-memory schema creation is simply ensuring the structure is initialized."""
|
|
||||||
|
|
||||||
async def acreate_schema(self) -> None:
|
|
||||||
"""In-memory schema creation is simply ensuring the structure is initialized."""
|
|
||||||
|
|
||||||
def get_time(self) -> float:
|
|
||||||
"""Get the current server time as a high resolution timestamp!"""
|
|
||||||
return time.time()
|
|
||||||
|
|
||||||
async def aget_time(self) -> float:
|
|
||||||
"""Get the current server time as a high resolution timestamp!"""
|
|
||||||
return self.get_time()
|
|
||||||
|
|
||||||
def update(
|
|
||||||
self,
|
|
||||||
keys: Sequence[str],
|
|
||||||
*,
|
|
||||||
group_ids: Optional[Sequence[Optional[str]]] = None,
|
|
||||||
time_at_least: Optional[float] = None,
|
|
||||||
) -> None:
|
|
||||||
if group_ids and len(keys) != len(group_ids):
|
|
||||||
raise ValueError("Length of keys must match length of group_ids")
|
|
||||||
for index, key in enumerate(keys):
|
|
||||||
group_id = group_ids[index] if group_ids else None
|
|
||||||
if time_at_least and time_at_least > self.get_time():
|
|
||||||
raise ValueError("time_at_least must be in the past")
|
|
||||||
self.records[key] = {"group_id": group_id, "updated_at": self.get_time()}
|
|
||||||
|
|
||||||
async def aupdate(
|
|
||||||
self,
|
|
||||||
keys: Sequence[str],
|
|
||||||
*,
|
|
||||||
group_ids: Optional[Sequence[Optional[str]]] = None,
|
|
||||||
time_at_least: Optional[float] = None,
|
|
||||||
) -> None:
|
|
||||||
self.update(keys, group_ids=group_ids, time_at_least=time_at_least)
|
|
||||||
|
|
||||||
def exists(self, keys: Sequence[str]) -> List[bool]:
|
|
||||||
return [key in self.records for key in keys]
|
|
||||||
|
|
||||||
async def aexists(self, keys: Sequence[str]) -> List[bool]:
|
|
||||||
return self.exists(keys)
|
|
||||||
|
|
||||||
def list_keys(
|
|
||||||
self,
|
|
||||||
*,
|
|
||||||
before: Optional[float] = None,
|
|
||||||
after: Optional[float] = None,
|
|
||||||
group_ids: Optional[Sequence[str]] = None,
|
|
||||||
limit: Optional[int] = None,
|
|
||||||
) -> List[str]:
|
|
||||||
result = []
|
|
||||||
for key, data in self.records.items():
|
|
||||||
if before and data["updated_at"] >= before:
|
|
||||||
continue
|
|
||||||
if after and data["updated_at"] <= after:
|
|
||||||
continue
|
|
||||||
if group_ids and data["group_id"] not in group_ids:
|
|
||||||
continue
|
|
||||||
result.append(key)
|
|
||||||
if limit:
|
|
||||||
return result[:limit]
|
|
||||||
return result
|
|
||||||
|
|
||||||
async def alist_keys(
|
|
||||||
self,
|
|
||||||
*,
|
|
||||||
before: Optional[float] = None,
|
|
||||||
after: Optional[float] = None,
|
|
||||||
group_ids: Optional[Sequence[str]] = None,
|
|
||||||
limit: Optional[int] = None,
|
|
||||||
) -> List[str]:
|
|
||||||
return self.list_keys(
|
|
||||||
before=before, after=after, group_ids=group_ids, limit=limit
|
|
||||||
)
|
|
||||||
|
|
||||||
def delete_keys(self, keys: Sequence[str]) -> None:
|
|
||||||
for key in keys:
|
|
||||||
if key in self.records:
|
|
||||||
del self.records[key]
|
|
||||||
|
|
||||||
async def adelete_keys(self, keys: Sequence[str]) -> None:
|
|
||||||
self.delete_keys(keys)
|
|
@@ -4,7 +4,7 @@ from unittest.mock import patch
|
|||||||
import pytest
|
import pytest
|
||||||
import pytest_asyncio
|
import pytest_asyncio
|
||||||
|
|
||||||
from tests.unit_tests.indexing.in_memory import InMemoryRecordManager
|
from langchain_core.indexing import InMemoryRecordManager
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture()
|
@pytest.fixture()
|
||||||
|
@@ -18,10 +18,9 @@ import pytest_asyncio
|
|||||||
from langchain_core.document_loaders.base import BaseLoader
|
from langchain_core.document_loaders.base import BaseLoader
|
||||||
from langchain_core.documents import Document
|
from langchain_core.documents import Document
|
||||||
from langchain_core.embeddings import Embeddings
|
from langchain_core.embeddings import Embeddings
|
||||||
from langchain_core.indexing import aindex, index
|
from langchain_core.indexing import InMemoryRecordManager, aindex, index
|
||||||
from langchain_core.indexing.api import _abatch, _HashedDocument
|
from langchain_core.indexing.api import _abatch, _HashedDocument
|
||||||
from langchain_core.vectorstores import VST, VectorStore
|
from langchain_core.vectorstores import VST, VectorStore
|
||||||
from tests.unit_tests.indexing.in_memory import InMemoryRecordManager
|
|
||||||
|
|
||||||
|
|
||||||
class ToyLoader(BaseLoader):
|
class ToyLoader(BaseLoader):
|
||||||
|
@@ -8,5 +8,6 @@ def test_all() -> None:
|
|||||||
"aindex",
|
"aindex",
|
||||||
"index",
|
"index",
|
||||||
"IndexingResult",
|
"IndexingResult",
|
||||||
|
"InMemoryRecordManager",
|
||||||
"RecordManager",
|
"RecordManager",
|
||||||
]
|
]
|
||||||
|
Reference in New Issue
Block a user