Chroma batching (#11203)

- **Description:** Chroma >= 0.4.10 added batch-size validation for
add/upsert operations. The maximum batch size depends on the SQLite
limits of the target system and therefore varies. This change adds
batch splitting for Chroma >= 0.4.10, because the aforementioned
validation is starting to surface in the Chroma community (among
LangChain users).
 - **Issue:** N/A
 - **Dependencies:** N/A
 - **Tag maintainer:** @eyurtsev
 - **Twitter handle:** t_azarov
This commit is contained in:
Trayan Azarov
2023-10-17 23:59:42 +03:00
committed by GitHub
parent 9373b9c004
commit 1fd21ed21c
2 changed files with 156 additions and 7 deletions

View File

@@ -558,12 +558,31 @@ class Chroma(VectorStore):
) )
embeddings = self._embedding_function.embed_documents(text) embeddings = self._embedding_function.embed_documents(text)
self._collection.update( if hasattr(
ids=ids, self._collection._client, "max_batch_size"
embeddings=embeddings, ): # for Chroma 0.4.10 and above
documents=text, from chromadb.utils.batch_utils import create_batches
metadatas=metadata,
) for batch in create_batches(
api=self._collection._client,
ids=ids,
metadatas=metadata,
documents=text,
embeddings=embeddings,
):
self._collection.update(
ids=batch[0],
embeddings=batch[1],
documents=batch[3],
metadatas=batch[2],
)
else:
self._collection.update(
ids=ids,
embeddings=embeddings,
documents=text,
metadatas=metadata,
)
@classmethod @classmethod
def from_texts( def from_texts(
@@ -607,7 +626,24 @@ class Chroma(VectorStore):
collection_metadata=collection_metadata, collection_metadata=collection_metadata,
**kwargs, **kwargs,
) )
chroma_collection.add_texts(texts=texts, metadatas=metadatas, ids=ids) if hasattr(
chroma_collection._client, "max_batch_size"
): # for Chroma 0.4.10 and above
from chromadb.utils.batch_utils import create_batches
for batch in create_batches(
api=chroma_collection._client,
ids=ids,
metadatas=metadatas,
documents=texts,
):
chroma_collection.add_texts(
texts=batch[3] if batch[3] else [],
metadatas=batch[2] if batch[2] else None,
ids=batch[0],
)
else:
chroma_collection.add_texts(texts=texts, metadatas=metadatas, ids=ids)
return chroma_collection return chroma_collection
@classmethod @classmethod

View File

@@ -1,7 +1,11 @@
"""Test Chroma functionality.""" """Test Chroma functionality."""
import uuid
import pytest import pytest
import requests
from langchain.docstore.document import Document from langchain.docstore.document import Document
from langchain.embeddings import FakeEmbeddings as Fak
from langchain.vectorstores import Chroma from langchain.vectorstores import Chroma
from tests.integration_tests.vectorstores.fake_embeddings import ( from tests.integration_tests.vectorstores.fake_embeddings import (
ConsistentFakeEmbeddings, ConsistentFakeEmbeddings,
@@ -301,3 +305,112 @@ def test_chroma_add_documents_mixed_metadata() -> None:
assert sorted(search, key=lambda d: d.page_content) == sorted( assert sorted(search, key=lambda d: d.page_content) == sorted(
docs, key=lambda d: d.page_content docs, key=lambda d: d.page_content
) )
def is_api_accessible(url: str, timeout: float = 5.0) -> bool:
    """Return True if a GET request to ``url`` is answered with HTTP 200.

    This probe is evaluated inside ``skipif`` decorators, i.e. while the test
    module is being imported, so it must never block for long or raise.

    Args:
        url: Endpoint to probe.
        timeout: Seconds to wait for the server before giving up. Without a
            timeout ``requests.get`` may wait indefinitely on a host that
            accepts the connection but never replies.

    Returns:
        True when the response status code is 200; False for any other
        status code or when the request itself fails.
    """
    try:
        response = requests.get(url, timeout=timeout)
    except requests.RequestException:
        # Connection refused, timeout, invalid URL, ...: treat as unreachable.
        return False
    return response.status_code == 200
def batch_support_chroma_version(version: str = "") -> bool:
    """Return True if a Chroma version is 0.4.10 or newer.

    Args:
        version: Version string to check, e.g. ``"0.4.10"``. When empty (the
            default), the version of the installed ``chromadb`` package is
            used.

    Returns:
        True when the version is at least 0.4.10, False when it is older or
        when no leading ``major.minor[.patch]`` number can be parsed.
    """
    import re

    if not version:
        import chromadb

        version = chromadb.__version__
    # Only the leading numeric components are considered, so suffixes such as
    # "rc1" or ".dev0" do not break parsing; a missing patch counts as 0.
    parsed = re.match(r"(\d+)\.(\d+)(?:\.(\d+))?", version)
    if parsed is None:
        return False
    major, minor, patch = (int(part or 0) for part in parsed.groups())
    # Compare as a tuple: checking each component separately would wrongly
    # reject newer releases such as 0.5.0 or 1.0.0.
    return (major, minor, patch) >= (0, 4, 10)
@pytest.mark.requires("chromadb")
@pytest.mark.skipif(
    not is_api_accessible("http://localhost:8000/api/v1/heartbeat"),
    reason="API not accessible",
)
@pytest.mark.skipif(
    not batch_support_chroma_version(),
    reason="ChromaDB version does not support batching",
)
def test_chroma_large_batch() -> None:
    """Add more texts than the client's ``max_batch_size`` via ``from_texts``.

    Requires a Chroma server on localhost:8000 and a chromadb version that
    supports batching; otherwise the test is skipped.
    """
    import chromadb

    client = chromadb.HttpClient()
    embedding_function = Fak(size=255)
    # A unique collection per run keeps the count assertion independent of
    # data left on the server by other tests or earlier runs.
    col = client.get_or_create_collection(
        f"large_batch_{uuid.uuid4().hex}",
        embedding_function=embedding_function.embed_documents,  # type: ignore
    )
    try:
        # 100 documents over the limit, so a single add call cannot hold them.
        docs = ["This is a test document"] * (client.max_batch_size + 100)
        Chroma.from_texts(
            client=client,
            collection_name=col.name,
            texts=docs,
            embedding=embedding_function,
            ids=[str(uuid.uuid4()) for _ in docs],
        )
        assert col.count() == len(docs)
    finally:
        client.delete_collection(col.name)
@pytest.mark.requires("chromadb")
@pytest.mark.skipif(
    not is_api_accessible("http://localhost:8000/api/v1/heartbeat"),
    reason="API not accessible",
)
@pytest.mark.skipif(
    not batch_support_chroma_version(),
    reason="ChromaDB version does not support batching",
)
def test_chroma_large_batch_update() -> None:
    """Update more documents than the client's ``max_batch_size`` at once.

    Requires a Chroma server on localhost:8000 and a chromadb version that
    supports batching; otherwise the test is skipped.
    """
    import chromadb

    client = chromadb.HttpClient()
    embedding_function = Fak(size=255)
    # A unique collection per run keeps the assertions independent of data
    # left on the server by other tests or earlier runs.
    col = client.get_or_create_collection(
        f"large_batch_update_{uuid.uuid4().hex}",
        embedding_function=embedding_function.embed_documents,  # type: ignore
    )
    try:
        docs = ["This is a test document"] * (client.max_batch_size + 100)
        ids = [str(uuid.uuid4()) for _ in docs]
        db = Chroma.from_texts(
            client=client,
            collection_name=col.name,
            texts=docs,
            embedding=embedding_function,
            ids=ids,
        )
        # max_batch_size + 90 replacements: still above the batch limit.
        new_docs = [
            Document(
                page_content="This is a new test document",
                metadata={"doc_id": f"{i}"},
            )
            for i in range(len(docs) - 10)
        ]
        new_ids = ids[: len(new_docs)]
        db.update_documents(ids=new_ids, documents=new_docs)

        # Updating must not add or drop records ...
        assert col.count() == len(docs)
        # ... and the updated content must be stored under the original id.
        first = col.get(ids=[new_ids[0]])
        assert first["documents"] == ["This is a new test document"]
        assert first["metadatas"] == [{"doc_id": "0"}]
    finally:
        client.delete_collection(col.name)
@pytest.mark.requires("chromadb")
@pytest.mark.skipif(
    not is_api_accessible("http://localhost:8000/api/v1/heartbeat"),
    reason="API not accessible",
)
@pytest.mark.skipif(
    batch_support_chroma_version(),
    reason="ChromaDB version supports batching; legacy path not exercised",
)
def test_chroma_legacy_batching() -> None:
    """Add a small set of texts on a chromadb version without batching.

    Runs only when ``batch_support_chroma_version()`` is False and a Chroma
    server is reachable on localhost:8000; otherwise the test is skipped.
    """
    import chromadb

    client = chromadb.HttpClient()
    embedding_function = Fak(size=255)
    # A unique collection per run keeps the count assertion independent of
    # data left on the server by other tests or earlier runs.
    col = client.get_or_create_collection(
        f"legacy_batching_{uuid.uuid4().hex}",
        embedding_function=embedding_function.embed_documents,  # type: ignore
    )
    try:
        docs = ["This is a test document"] * 100
        Chroma.from_texts(
            client=client,
            collection_name=col.name,
            texts=docs,
            embedding=embedding_function,
            ids=[str(uuid.uuid4()) for _ in docs],
        )
        assert col.count() == len(docs)
    finally:
        client.delete_collection(col.name)