mirror of
https://github.com/hwchase17/langchain.git
synced 2025-08-07 12:06:43 +00:00
fix empty ids when metadatas is provided (#8127)
Fixes https://github.com/hwchase17/langchain/issues/7865 and https://github.com/hwchase17/langchain/issues/8061 - [x] fixes returning empty ids when metadatas argument is provided @baskaryan --------- Co-authored-by: Bagatur <baskaryan@gmail.com>
This commit is contained in:
parent
62b8b459c6
commit
b7d6e1909c
@ -171,38 +171,52 @@ class Chroma(VectorStore):
|
|||||||
if ids is None:
|
if ids is None:
|
||||||
ids = [str(uuid.uuid1()) for _ in texts]
|
ids = [str(uuid.uuid1()) for _ in texts]
|
||||||
embeddings = None
|
embeddings = None
|
||||||
|
texts = list(texts)
|
||||||
if self._embedding_function is not None:
|
if self._embedding_function is not None:
|
||||||
embeddings = self._embedding_function.embed_documents(list(texts))
|
embeddings = self._embedding_function.embed_documents(texts)
|
||||||
|
|
||||||
if metadatas:
|
if metadatas:
|
||||||
texts = list(texts)
|
# fill metadatas with empty dicts if somebody
|
||||||
empty = []
|
# did not specify metadata for all texts
|
||||||
non_empty = []
|
length_diff = len(texts) - len(metadatas)
|
||||||
for i, m in enumerate(metadatas):
|
if length_diff:
|
||||||
|
metadatas = metadatas + [{}] * length_diff
|
||||||
|
empty_ids = []
|
||||||
|
non_empty_ids = []
|
||||||
|
for idx, m in enumerate(metadatas):
|
||||||
if m:
|
if m:
|
||||||
non_empty.append(i)
|
non_empty_ids.append(idx)
|
||||||
else:
|
else:
|
||||||
empty.append(i)
|
empty_ids.append(idx)
|
||||||
if non_empty:
|
if non_empty_ids:
|
||||||
metadatas = [metadatas[i] for i in non_empty]
|
metadatas = [metadatas[idx] for idx in non_empty_ids]
|
||||||
texts_with_metadatas = [texts[i] for i in non_empty]
|
texts_with_metadatas = [texts[idx] for idx in non_empty_ids]
|
||||||
embeddings_with_metadatas = (
|
embeddings_with_metadatas = (
|
||||||
[embeddings[i] for i in non_empty] if embeddings else None
|
[embeddings[idx] for idx in non_empty_ids] if embeddings else None
|
||||||
)
|
)
|
||||||
ids_with_metadata = [ids[i] for i in non_empty]
|
ids_with_metadata = [ids[idx] for idx in non_empty_ids]
|
||||||
self._collection.upsert(
|
self._collection.upsert(
|
||||||
metadatas=metadatas,
|
metadatas=metadatas,
|
||||||
embeddings=embeddings_with_metadatas,
|
embeddings=embeddings_with_metadatas,
|
||||||
documents=texts_with_metadatas,
|
documents=texts_with_metadatas,
|
||||||
ids=ids_with_metadata,
|
ids=ids_with_metadata,
|
||||||
)
|
)
|
||||||
|
if empty_ids:
|
||||||
texts = [texts[j] for j in empty]
|
texts_without_metadatas = [texts[j] for j in empty_ids]
|
||||||
embeddings = [embeddings[j] for j in empty] if embeddings else None
|
embeddings_without_metadatas = (
|
||||||
ids = [ids[j] for j in empty]
|
[embeddings[j] for j in empty_ids] if embeddings else None
|
||||||
|
)
|
||||||
if texts:
|
ids_without_metadatas = [ids[j] for j in empty_ids]
|
||||||
self._collection.upsert(embeddings=embeddings, documents=texts, ids=ids)
|
self._collection.upsert(
|
||||||
|
embeddings=embeddings_without_metadatas,
|
||||||
|
documents=texts_without_metadatas,
|
||||||
|
ids=ids_without_metadatas,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
self._collection.upsert(
|
||||||
|
embeddings=embeddings,
|
||||||
|
documents=texts,
|
||||||
|
ids=ids,
|
||||||
|
)
|
||||||
return ids
|
return ids
|
||||||
|
|
||||||
def similarity_search(
|
def similarity_search(
|
||||||
|
@ -294,7 +294,9 @@ def test_chroma_add_documents_mixed_metadata() -> None:
|
|||||||
Document(page_content="foo"),
|
Document(page_content="foo"),
|
||||||
Document(page_content="bar", metadata={"baz": 1}),
|
Document(page_content="bar", metadata={"baz": 1}),
|
||||||
]
|
]
|
||||||
db.add_documents(docs)
|
ids = ["0", "1"]
|
||||||
|
actual_ids = db.add_documents(docs, ids=ids)
|
||||||
|
assert actual_ids == ids
|
||||||
search = db.similarity_search("foo bar")
|
search = db.similarity_search("foo bar")
|
||||||
assert sorted(search, key=lambda d: d.page_content) == sorted(
|
assert sorted(search, key=lambda d: d.page_content) == sorted(
|
||||||
docs, key=lambda d: d.page_content
|
docs, key=lambda d: d.page_content
|
||||||
|
Loading…
Reference in New Issue
Block a user