mirror of
https://github.com/hwchase17/langchain.git
synced 2025-07-05 20:58:25 +00:00
community[patch]: Fix in memory vectorstore to take into account ids when adding docs (#21384)
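Adding texts to the in-memory vector store should respect `ids` if they are passed, instead of always generating a new UUID for every document; the ids actually used are returned.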
This commit is contained in:
parent
80170da6c5
commit
6a1d61dbf1
@ -38,21 +38,23 @@ class InMemoryVectorStore(VectorStore):
|
|||||||
self,
|
self,
|
||||||
texts: Iterable[str],
|
texts: Iterable[str],
|
||||||
metadatas: Optional[List[dict]] = None,
|
metadatas: Optional[List[dict]] = None,
|
||||||
|
ids: Optional[Sequence[str]] = None,
|
||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
) -> List[str]:
|
) -> List[str]:
|
||||||
ids = []
|
"""Add texts to the store."""
|
||||||
vectors = self.embedding.embed_documents(list(texts))
|
vectors = self.embedding.embed_documents(list(texts))
|
||||||
|
ids_ = []
|
||||||
|
|
||||||
for i, text in enumerate(texts):
|
for i, text in enumerate(texts):
|
||||||
doc_id = str(uuid.uuid4())
|
doc_id = ids[i] if ids else str(uuid.uuid4())
|
||||||
ids.append(doc_id)
|
ids_.append(doc_id)
|
||||||
self.store[doc_id] = {
|
self.store[doc_id] = {
|
||||||
"id": doc_id,
|
"id": doc_id,
|
||||||
"vector": vectors[i],
|
"vector": vectors[i],
|
||||||
"text": text,
|
"text": text,
|
||||||
"metadata": metadatas[i] if metadatas else {},
|
"metadata": metadatas[i] if metadatas else {},
|
||||||
}
|
}
|
||||||
return ids
|
return ids_
|
||||||
|
|
||||||
async def aadd_texts(
|
async def aadd_texts(
|
||||||
self,
|
self,
|
||||||
@ -185,7 +187,7 @@ class InMemoryVectorStore(VectorStore):
|
|||||||
store = cls(
|
store = cls(
|
||||||
embedding=embedding,
|
embedding=embedding,
|
||||||
)
|
)
|
||||||
store.add_texts(texts=texts, metadatas=metadatas)
|
store.add_texts(texts=texts, metadatas=metadatas, **kwargs)
|
||||||
return store
|
return store
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
|
@ -21,6 +21,19 @@ async def test_inmemory() -> None:
|
|||||||
assert output2[0][1] > output2[1][1]
|
assert output2[0][1] > output2[1][1]
|
||||||
|
|
||||||
|
|
||||||
|
async def test_add_by_ids() -> None:
|
||||||
|
vectorstore = InMemoryVectorStore(embedding=ConsistentFakeEmbeddings())
|
||||||
|
|
||||||
|
# Check sync version
|
||||||
|
ids1 = vectorstore.add_texts(["foo", "bar", "baz"], ids=["1", "2", "3"])
|
||||||
|
assert ids1 == ["1", "2", "3"]
|
||||||
|
assert sorted(vectorstore.store.keys()) == ["1", "2", "3"]
|
||||||
|
|
||||||
|
ids2 = await vectorstore.aadd_texts(["foo", "bar", "baz"], ids=["4", "5", "6"])
|
||||||
|
assert ids2 == ["4", "5", "6"]
|
||||||
|
assert sorted(vectorstore.store.keys()) == ["1", "2", "3", "4", "5", "6"]
|
||||||
|
|
||||||
|
|
||||||
async def test_inmemory_mmr() -> None:
|
async def test_inmemory_mmr() -> None:
|
||||||
texts = ["foo", "foo", "fou", "foy"]
|
texts = ["foo", "foo", "fou", "foy"]
|
||||||
docsearch = await InMemoryVectorStore.afrom_texts(texts, ConsistentFakeEmbeddings())
|
docsearch = await InMemoryVectorStore.afrom_texts(texts, ConsistentFakeEmbeddings())
|
||||||
|
Loading…
Reference in New Issue
Block a user