From 1fd21ed21cf1dbc2c7a9f0c4511e2cafd5e95eae Mon Sep 17 00:00:00 2001 From: Trayan Azarov Date: Tue, 17 Oct 2023 23:59:42 +0300 Subject: [PATCH] Chroma batching (#11203) - **Description:** Chroma >= 0.4.10 added support for batch sizes validation of add/upsert. This batch size is dependent on the SQLite limits of the target system and varies. In this change, for Chroma>=0.4.10 batch splitting was added as the aforementioned validation is starting to surface in the Chroma community (users using LC) - **Issue:** N/A - **Dependencies:** N/A - **Tag maintainer:** @eyurtsev - **Twitter handle:** t_azarov --- .../langchain/vectorstores/chroma.py | 50 ++++++-- .../vectorstores/test_chroma.py | 113 ++++++++++++++++++ 2 files changed, 156 insertions(+), 7 deletions(-) diff --git a/libs/langchain/langchain/vectorstores/chroma.py b/libs/langchain/langchain/vectorstores/chroma.py index c666e60c642..26b6a1da8c5 100644 --- a/libs/langchain/langchain/vectorstores/chroma.py +++ b/libs/langchain/langchain/vectorstores/chroma.py @@ -558,12 +558,31 @@ class Chroma(VectorStore): ) embeddings = self._embedding_function.embed_documents(text) - self._collection.update( - ids=ids, - embeddings=embeddings, - documents=text, - metadatas=metadata, - ) + if hasattr( + self._collection._client, "max_batch_size" + ): # for Chroma 0.4.10 and above + from chromadb.utils.batch_utils import create_batches + + for batch in create_batches( + api=self._collection._client, + ids=ids, + metadatas=metadata, + documents=text, + embeddings=embeddings, + ): + self._collection.update( + ids=batch[0], + embeddings=batch[1], + documents=batch[3], + metadatas=batch[2], + ) + else: + self._collection.update( + ids=ids, + embeddings=embeddings, + documents=text, + metadatas=metadata, + ) @classmethod def from_texts( @@ -607,7 +626,24 @@ class Chroma(VectorStore): collection_metadata=collection_metadata, **kwargs, ) - chroma_collection.add_texts(texts=texts, metadatas=metadatas, ids=ids) + if hasattr( + chroma_collection._client, "max_batch_size" + ): # for Chroma 0.4.10 and above + from chromadb.utils.batch_utils import create_batches + + for batch in create_batches( + api=chroma_collection._client, + ids=ids, + metadatas=metadatas, + documents=texts, + ): + chroma_collection.add_texts( + texts=batch[3] if batch[3] else [], + metadatas=batch[2] if batch[2] else None, + ids=batch[0], + ) + else: + chroma_collection.add_texts(texts=texts, metadatas=metadatas, ids=ids) return chroma_collection @classmethod diff --git a/libs/langchain/tests/integration_tests/vectorstores/test_chroma.py b/libs/langchain/tests/integration_tests/vectorstores/test_chroma.py index 99f7360537d..cde5d3cc95a 100644 --- a/libs/langchain/tests/integration_tests/vectorstores/test_chroma.py +++ b/libs/langchain/tests/integration_tests/vectorstores/test_chroma.py @@ -1,7 +1,11 @@ """Test Chroma functionality.""" +import uuid + import pytest +import requests from langchain.docstore.document import Document +from langchain.embeddings import FakeEmbeddings as Fak from langchain.vectorstores import Chroma from tests.integration_tests.vectorstores.fake_embeddings import ( ConsistentFakeEmbeddings, @@ -301,3 +305,112 @@ def test_chroma_add_documents_mixed_metadata() -> None: assert sorted(search, key=lambda d: d.page_content) == sorted( docs, key=lambda d: d.page_content ) + + +def is_api_accessible(url: str) -> bool: + try: + response = requests.get(url) + return response.status_code == 200 + except Exception: + return False + + +def batch_support_chroma_version() -> bool: + import chromadb + + major, minor, patch = chromadb.__version__.split(".") + if int(major) == 0 and int(minor) >= 4 and int(patch) >= 10: + return True + return False + + +@pytest.mark.requires("chromadb") +@pytest.mark.skipif( + not is_api_accessible("http://localhost:8000/api/v1/heartbeat"), + reason="API not accessible", +) +@pytest.mark.skipif( + not batch_support_chroma_version(), + reason="ChromaDB version does not support batching", +) +def test_chroma_large_batch() -> None: + import chromadb + + client = chromadb.HttpClient() + embedding_function = Fak(size=255) + col = client.get_or_create_collection( + "my_collection", + embedding_function=embedding_function.embed_documents, # type: ignore + ) + docs = ["This is a test document"] * (client.max_batch_size + 100) + Chroma.from_texts( + client=client, + collection_name=col.name, + texts=docs, + embedding=embedding_function, + ids=[str(uuid.uuid4()) for _ in range(len(docs))], + ) + + +@pytest.mark.requires("chromadb") +@pytest.mark.skipif( + not is_api_accessible("http://localhost:8000/api/v1/heartbeat"), + reason="API not accessible", +) +@pytest.mark.skipif( + not batch_support_chroma_version(), + reason="ChromaDB version does not support batching", +) +def test_chroma_large_batch_update() -> None: + import chromadb + + client = chromadb.HttpClient() + embedding_function = Fak(size=255) + col = client.get_or_create_collection( + "my_collection", + embedding_function=embedding_function.embed_documents, # type: ignore + ) + docs = ["This is a test document"] * (client.max_batch_size + 100) + ids = [str(uuid.uuid4()) for _ in range(len(docs))] + db = Chroma.from_texts( + client=client, + collection_name=col.name, + texts=docs, + embedding=embedding_function, + ids=ids, + ) + new_docs = [ + Document( + page_content="This is a new test document", metadata={"doc_id": f"{i}"} + ) + for i in range(len(docs) - 10) + ] + new_ids = [_id for _id in ids[: len(new_docs)]] + db.update_documents(ids=new_ids, documents=new_docs) + + +@pytest.mark.requires("chromadb") +@pytest.mark.skipif( + not is_api_accessible("http://localhost:8000/api/v1/heartbeat"), + reason="API not accessible", +) +@pytest.mark.skipif( + batch_support_chroma_version(), reason="ChromaDB version does not support batching" +) +def test_chroma_legacy_batching() -> None: + import chromadb + + client = chromadb.HttpClient() + embedding_function = Fak(size=255) + col = client.get_or_create_collection( + "my_collection", + embedding_function=embedding_function.embed_documents, # type: ignore + ) + docs = ["This is a test document"] * 100 + Chroma.from_texts( + client=client, + collection_name=col.name, + texts=docs, + embedding=embedding_function, + ids=[str(uuid.uuid4()) for _ in range(len(docs))], + )