mirror of
https://github.com/hwchase17/langchain.git
synced 2025-09-01 19:12:42 +00:00
Chroma batching (#11203)
- **Description:** Chroma >= 0.4.10 added batch-size validation for add/upsert. The maximum batch size depends on the SQLite limits of the target system and therefore varies. This change adds batch splitting for Chroma >= 0.4.10, because errors from the aforementioned validation are starting to surface in the Chroma community (among users of LangChain). - **Issue:** N/A - **Dependencies:** N/A - **Tag maintainer:** @eyurtsev - **Twitter handle:** t_azarov
This commit is contained in:
@@ -558,12 +558,31 @@ class Chroma(VectorStore):
|
|||||||
)
|
)
|
||||||
embeddings = self._embedding_function.embed_documents(text)
|
embeddings = self._embedding_function.embed_documents(text)
|
||||||
|
|
||||||
self._collection.update(
|
if hasattr(
|
||||||
ids=ids,
|
self._collection._client, "max_batch_size"
|
||||||
embeddings=embeddings,
|
): # for Chroma 0.4.10 and above
|
||||||
documents=text,
|
from chromadb.utils.batch_utils import create_batches
|
||||||
metadatas=metadata,
|
|
||||||
)
|
for batch in create_batches(
|
||||||
|
api=self._collection._client,
|
||||||
|
ids=ids,
|
||||||
|
metadatas=metadata,
|
||||||
|
documents=text,
|
||||||
|
embeddings=embeddings,
|
||||||
|
):
|
||||||
|
self._collection.update(
|
||||||
|
ids=batch[0],
|
||||||
|
embeddings=batch[1],
|
||||||
|
documents=batch[3],
|
||||||
|
metadatas=batch[2],
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
self._collection.update(
|
||||||
|
ids=ids,
|
||||||
|
embeddings=embeddings,
|
||||||
|
documents=text,
|
||||||
|
metadatas=metadata,
|
||||||
|
)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_texts(
|
def from_texts(
|
||||||
@@ -607,7 +626,24 @@ class Chroma(VectorStore):
|
|||||||
collection_metadata=collection_metadata,
|
collection_metadata=collection_metadata,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
)
|
)
|
||||||
chroma_collection.add_texts(texts=texts, metadatas=metadatas, ids=ids)
|
if hasattr(
|
||||||
|
chroma_collection._client, "max_batch_size"
|
||||||
|
): # for Chroma 0.4.10 and above
|
||||||
|
from chromadb.utils.batch_utils import create_batches
|
||||||
|
|
||||||
|
for batch in create_batches(
|
||||||
|
api=chroma_collection._client,
|
||||||
|
ids=ids,
|
||||||
|
metadatas=metadatas,
|
||||||
|
documents=texts,
|
||||||
|
):
|
||||||
|
chroma_collection.add_texts(
|
||||||
|
texts=batch[3] if batch[3] else [],
|
||||||
|
metadatas=batch[2] if batch[2] else None,
|
||||||
|
ids=batch[0],
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
chroma_collection.add_texts(texts=texts, metadatas=metadatas, ids=ids)
|
||||||
return chroma_collection
|
return chroma_collection
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
|
@@ -1,7 +1,11 @@
|
|||||||
"""Test Chroma functionality."""
|
"""Test Chroma functionality."""
|
||||||
|
import uuid
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
import requests
|
||||||
|
|
||||||
from langchain.docstore.document import Document
|
from langchain.docstore.document import Document
|
||||||
|
from langchain.embeddings import FakeEmbeddings as Fak
|
||||||
from langchain.vectorstores import Chroma
|
from langchain.vectorstores import Chroma
|
||||||
from tests.integration_tests.vectorstores.fake_embeddings import (
|
from tests.integration_tests.vectorstores.fake_embeddings import (
|
||||||
ConsistentFakeEmbeddings,
|
ConsistentFakeEmbeddings,
|
||||||
@@ -301,3 +305,112 @@ def test_chroma_add_documents_mixed_metadata() -> None:
|
|||||||
assert sorted(search, key=lambda d: d.page_content) == sorted(
|
assert sorted(search, key=lambda d: d.page_content) == sorted(
|
||||||
docs, key=lambda d: d.page_content
|
docs, key=lambda d: d.page_content
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def is_api_accessible(url: str, timeout: float = 5.0) -> bool:
    """Return whether a GET request to ``url`` answers with HTTP 200.

    Args:
        url: Address to probe.
        timeout: Seconds to wait for the server before giving up. Without a
            bound, an unresponsive host could block the caller indefinitely.

    Returns:
        True if the request completes with status code 200; False if the
        status differs or the request fails for any reason.
    """
    try:
        response = requests.get(url, timeout=timeout)
    except Exception:
        # Deliberately best-effort: any failure means "not accessible".
        return False
    return response.status_code == 200
|
||||||
|
|
||||||
|
|
||||||
|
def batch_support_chroma_version(version: str = "") -> bool:
    """Return whether a chromadb version is at least 0.4.10.

    Args:
        version: Dotted version string to check. When empty (the default),
            the ``__version__`` of the installed ``chromadb`` package is used.

    Returns:
        True if the version is 0.4.10 or newer, False otherwise.
    """
    import re

    if not version:
        import chromadb

        version = chromadb.__version__

    # Keep the leading integer of each of the first three components so that
    # suffixed versions such as "0.4.10.dev1" or "0.4.10rc1" still parse.
    parsed = []
    for component in version.split(".")[:3]:
        leading_digits = re.match(r"\d+", component)
        if leading_digits is None:
            break
        parsed.append(int(leading_digits.group()))

    # Compare as a tuple: checking each component independently would reject
    # newer releases such as 0.5.0 (patch < 10) or 1.0.0 (major != 0).
    return tuple(parsed) >= (0, 4, 10)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.requires("chromadb")
@pytest.mark.skipif(
    not is_api_accessible("http://localhost:8000/api/v1/heartbeat"),
    reason="API not accessible",
)
@pytest.mark.skipif(
    not batch_support_chroma_version(),
    reason="ChromaDB version does not support batching",
)
def test_chroma_large_batch() -> None:
    """Build a store from more texts than the client's ``max_batch_size``."""
    import chromadb

    http_client = chromadb.HttpClient()
    fake_embeddings = Fak(size=255)
    collection = http_client.get_or_create_collection(
        "my_collection",
        embedding_function=fake_embeddings.embed_documents,  # type: ignore
    )

    # 100 more documents than the client's reported max_batch_size.
    texts = ["This is a test document"] * (http_client.max_batch_size + 100)
    doc_ids = [str(uuid.uuid4()) for _ in texts]

    Chroma.from_texts(
        client=http_client,
        collection_name=collection.name,
        texts=texts,
        embedding=fake_embeddings,
        ids=doc_ids,
    )
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.requires("chromadb")
@pytest.mark.skipif(
    not is_api_accessible("http://localhost:8000/api/v1/heartbeat"),
    reason="API not accessible",
)
@pytest.mark.skipif(
    not batch_support_chroma_version(),
    reason="ChromaDB version does not support batching",
)
def test_chroma_large_batch_update() -> None:
    """Insert an oversized batch, then update all but ten of its documents."""
    import chromadb

    http_client = chromadb.HttpClient()
    fake_embeddings = Fak(size=255)
    collection = http_client.get_or_create_collection(
        "my_collection",
        embedding_function=fake_embeddings.embed_documents,  # type: ignore
    )

    # 100 more documents than the client's reported max_batch_size.
    texts = ["This is a test document"] * (http_client.max_batch_size + 100)
    ids = [str(uuid.uuid4()) for _ in texts]
    store = Chroma.from_texts(
        client=http_client,
        collection_name=collection.name,
        texts=texts,
        embedding=fake_embeddings,
        ids=ids,
    )

    # Replacement documents for every inserted id except the last ten.
    update_count = len(texts) - 10
    replacements = []
    for index in range(update_count):
        replacements.append(
            Document(
                page_content="This is a new test document",
                metadata={"doc_id": f"{index}"},
            )
        )
    target_ids = ids[: len(replacements)]
    store.update_documents(ids=target_ids, documents=replacements)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.requires("chromadb")
@pytest.mark.skipif(
    not is_api_accessible("http://localhost:8000/api/v1/heartbeat"),
    reason="API not accessible",
)
@pytest.mark.skipif(
    # Skipped when batching IS supported: this test targets the legacy path,
    # so the reason must say so rather than claim batching is unsupported.
    batch_support_chroma_version(),
    reason="ChromaDB version supports batching",
)
def test_chroma_legacy_batching() -> None:
    """Build a store from 100 texts on a chromadb version without batching."""
    import chromadb

    client = chromadb.HttpClient()
    embedding_function = Fak(size=255)
    col = client.get_or_create_collection(
        "my_collection",
        embedding_function=embedding_function.embed_documents,  # type: ignore
    )
    docs = ["This is a test document"] * 100
    Chroma.from_texts(
        client=client,
        collection_name=col.name,
        texts=docs,
        embedding=embedding_function,
        ids=[str(uuid.uuid4()) for _ in docs],
    )
|
||||||
|
Reference in New Issue
Block a user