From 6a1d61dbf19de4581b0f8d7582ef4f784b3e2809 Mon Sep 17 00:00:00 2001 From: Eugene Yurtsev Date: Tue, 7 May 2024 15:05:16 -0400 Subject: [PATCH] community[patch]: Fix in memory vectorstore to take into account ids when adding docs (#21384) Should respect `ids` if passed --- .../langchain_community/vectorstores/inmemory.py | 12 +++++++----- .../tests/unit_tests/vectorstores/test_inmemory.py | 13 +++++++++++++ 2 files changed, 20 insertions(+), 5 deletions(-) diff --git a/libs/community/langchain_community/vectorstores/inmemory.py b/libs/community/langchain_community/vectorstores/inmemory.py index 7b73e0843d1..e440dc9d933 100644 --- a/libs/community/langchain_community/vectorstores/inmemory.py +++ b/libs/community/langchain_community/vectorstores/inmemory.py @@ -38,21 +38,23 @@ class InMemoryVectorStore(VectorStore): self, texts: Iterable[str], metadatas: Optional[List[dict]] = None, + ids: Optional[Sequence[str]] = None, **kwargs: Any, ) -> List[str]: - ids = [] + """Add texts to the store.""" vectors = self.embedding.embed_documents(list(texts)) + ids_ = [] for i, text in enumerate(texts): - doc_id = str(uuid.uuid4()) - ids.append(doc_id) + doc_id = ids[i] if ids else str(uuid.uuid4()) + ids_.append(doc_id) self.store[doc_id] = { "id": doc_id, "vector": vectors[i], "text": text, "metadata": metadatas[i] if metadatas else {}, } - return ids + return ids_ async def aadd_texts( self, @@ -185,7 +187,7 @@ class InMemoryVectorStore(VectorStore): store = cls( embedding=embedding, ) - store.add_texts(texts=texts, metadatas=metadatas) + store.add_texts(texts=texts, metadatas=metadatas, **kwargs) return store @classmethod diff --git a/libs/community/tests/unit_tests/vectorstores/test_inmemory.py b/libs/community/tests/unit_tests/vectorstores/test_inmemory.py index d571d680d20..e7e50082271 100644 --- a/libs/community/tests/unit_tests/vectorstores/test_inmemory.py +++ b/libs/community/tests/unit_tests/vectorstores/test_inmemory.py @@ -21,6 +21,19 @@ async def test_inmemory() -> None: assert output2[0][1] > output2[1][1] +async def test_add_by_ids() -> None: + vectorstore = InMemoryVectorStore(embedding=ConsistentFakeEmbeddings()) + + # Check sync version + ids1 = vectorstore.add_texts(["foo", "bar", "baz"], ids=["1", "2", "3"]) + assert ids1 == ["1", "2", "3"] + assert sorted(vectorstore.store.keys()) == ["1", "2", "3"] + + ids2 = await vectorstore.aadd_texts(["foo", "bar", "baz"], ids=["4", "5", "6"]) + assert ids2 == ["4", "5", "6"] + assert sorted(vectorstore.store.keys()) == ["1", "2", "3", "4", "5", "6"] + + async def test_inmemory_mmr() -> None: texts = ["foo", "foo", "fou", "foy"] docsearch = await InMemoryVectorStore.afrom_texts(texts, ConsistentFakeEmbeddings())