mirror of
https://github.com/hwchase17/langchain.git
synced 2025-07-05 04:38:26 +00:00
community[patch]: Fix in memory vectorstore to take into account ids when adding docs (#21384)
Should respect `ids` if passed
This commit is contained in:
parent
80170da6c5
commit
6a1d61dbf1
@ -38,21 +38,23 @@ class InMemoryVectorStore(VectorStore):
|
||||
self,
|
||||
texts: Iterable[str],
|
||||
metadatas: Optional[List[dict]] = None,
|
||||
ids: Optional[Sequence[str]] = None,
|
||||
**kwargs: Any,
|
||||
) -> List[str]:
|
||||
ids = []
|
||||
"""Add texts to the store."""
|
||||
vectors = self.embedding.embed_documents(list(texts))
|
||||
ids_ = []
|
||||
|
||||
for i, text in enumerate(texts):
|
||||
doc_id = str(uuid.uuid4())
|
||||
ids.append(doc_id)
|
||||
doc_id = ids[i] if ids else str(uuid.uuid4())
|
||||
ids_.append(doc_id)
|
||||
self.store[doc_id] = {
|
||||
"id": doc_id,
|
||||
"vector": vectors[i],
|
||||
"text": text,
|
||||
"metadata": metadatas[i] if metadatas else {},
|
||||
}
|
||||
return ids
|
||||
return ids_
|
||||
|
||||
async def aadd_texts(
|
||||
self,
|
||||
@ -185,7 +187,7 @@ class InMemoryVectorStore(VectorStore):
|
||||
store = cls(
|
||||
embedding=embedding,
|
||||
)
|
||||
store.add_texts(texts=texts, metadatas=metadatas)
|
||||
store.add_texts(texts=texts, metadatas=metadatas, **kwargs)
|
||||
return store
|
||||
|
||||
@classmethod
|
||||
|
@ -21,6 +21,19 @@ async def test_inmemory() -> None:
|
||||
assert output2[0][1] > output2[1][1]
|
||||
|
||||
|
||||
async def test_add_by_ids() -> None:
|
||||
vectorstore = InMemoryVectorStore(embedding=ConsistentFakeEmbeddings())
|
||||
|
||||
# Check sync version
|
||||
ids1 = vectorstore.add_texts(["foo", "bar", "baz"], ids=["1", "2", "3"])
|
||||
assert ids1 == ["1", "2", "3"]
|
||||
assert sorted(vectorstore.store.keys()) == ["1", "2", "3"]
|
||||
|
||||
ids2 = await vectorstore.aadd_texts(["foo", "bar", "baz"], ids=["4", "5", "6"])
|
||||
assert ids2 == ["4", "5", "6"]
|
||||
assert sorted(vectorstore.store.keys()) == ["1", "2", "3", "4", "5", "6"]
|
||||
|
||||
|
||||
async def test_inmemory_mmr() -> None:
|
||||
texts = ["foo", "foo", "fou", "foy"]
|
||||
docsearch = await InMemoryVectorStore.afrom_texts(texts, ConsistentFakeEmbeddings())
|
||||
|
Loading…
Reference in New Issue
Block a user