mirror of
https://github.com/hwchase17/langchain.git
synced 2025-08-30 17:29:56 +00:00
Chroma batching (#11203)
- **Description:** Chroma >= 0.4.10 added support for batch sizes validation of add/upsert. This batch size is dependent on the SQLite limits of the target system and varies. In this change, for Chroma>=0.4.10 batch splitting was added as the aforementioned validation is starting to surface in the Chroma community (users using LC) - **Issue:** N/A - **Dependencies:** N/A - **Tag maintainer:** @eyurtsev - **Twitter handle:** t_azarov
This commit is contained in:
parent
9373b9c004
commit
1fd21ed21c
@ -558,12 +558,31 @@ class Chroma(VectorStore):
|
||||
)
|
||||
embeddings = self._embedding_function.embed_documents(text)
|
||||
|
||||
self._collection.update(
|
||||
ids=ids,
|
||||
embeddings=embeddings,
|
||||
documents=text,
|
||||
metadatas=metadata,
|
||||
)
|
||||
if hasattr(
|
||||
self._collection._client, "max_batch_size"
|
||||
): # for Chroma 0.4.10 and above
|
||||
from chromadb.utils.batch_utils import create_batches
|
||||
|
||||
for batch in create_batches(
|
||||
api=self._collection._client,
|
||||
ids=ids,
|
||||
metadatas=metadata,
|
||||
documents=text,
|
||||
embeddings=embeddings,
|
||||
):
|
||||
self._collection.update(
|
||||
ids=batch[0],
|
||||
embeddings=batch[1],
|
||||
documents=batch[3],
|
||||
metadatas=batch[2],
|
||||
)
|
||||
else:
|
||||
self._collection.update(
|
||||
ids=ids,
|
||||
embeddings=embeddings,
|
||||
documents=text,
|
||||
metadatas=metadata,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def from_texts(
|
||||
@ -607,7 +626,24 @@ class Chroma(VectorStore):
|
||||
collection_metadata=collection_metadata,
|
||||
**kwargs,
|
||||
)
|
||||
chroma_collection.add_texts(texts=texts, metadatas=metadatas, ids=ids)
|
||||
if hasattr(
|
||||
chroma_collection._client, "max_batch_size"
|
||||
): # for Chroma 0.4.10 and above
|
||||
from chromadb.utils.batch_utils import create_batches
|
||||
|
||||
for batch in create_batches(
|
||||
api=chroma_collection._client,
|
||||
ids=ids,
|
||||
metadatas=metadatas,
|
||||
documents=texts,
|
||||
):
|
||||
chroma_collection.add_texts(
|
||||
texts=batch[3] if batch[3] else [],
|
||||
metadatas=batch[2] if batch[2] else None,
|
||||
ids=batch[0],
|
||||
)
|
||||
else:
|
||||
chroma_collection.add_texts(texts=texts, metadatas=metadatas, ids=ids)
|
||||
return chroma_collection
|
||||
|
||||
@classmethod
|
||||
|
@ -1,7 +1,11 @@
|
||||
"""Test Chroma functionality."""
|
||||
import uuid
|
||||
|
||||
import pytest
|
||||
import requests
|
||||
|
||||
from langchain.docstore.document import Document
|
||||
from langchain.embeddings import FakeEmbeddings as Fak
|
||||
from langchain.vectorstores import Chroma
|
||||
from tests.integration_tests.vectorstores.fake_embeddings import (
|
||||
ConsistentFakeEmbeddings,
|
||||
@ -301,3 +305,112 @@ def test_chroma_add_documents_mixed_metadata() -> None:
|
||||
assert sorted(search, key=lambda d: d.page_content) == sorted(
|
||||
docs, key=lambda d: d.page_content
|
||||
)
|
||||
|
||||
|
||||
def is_api_accessible(url: str) -> bool:
|
||||
try:
|
||||
response = requests.get(url)
|
||||
return response.status_code == 200
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
|
||||
def batch_support_chroma_version() -> bool:
|
||||
import chromadb
|
||||
|
||||
major, minor, patch = chromadb.__version__.split(".")
|
||||
if int(major) == 0 and int(minor) >= 4 and int(patch) >= 10:
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
@pytest.mark.requires("chromadb")
|
||||
@pytest.mark.skipif(
|
||||
not is_api_accessible("http://localhost:8000/api/v1/heartbeat"),
|
||||
reason="API not accessible",
|
||||
)
|
||||
@pytest.mark.skipif(
|
||||
not batch_support_chroma_version(),
|
||||
reason="ChromaDB version does not support batching",
|
||||
)
|
||||
def test_chroma_large_batch() -> None:
|
||||
import chromadb
|
||||
|
||||
client = chromadb.HttpClient()
|
||||
embedding_function = Fak(size=255)
|
||||
col = client.get_or_create_collection(
|
||||
"my_collection",
|
||||
embedding_function=embedding_function.embed_documents, # type: ignore
|
||||
)
|
||||
docs = ["This is a test document"] * (client.max_batch_size + 100)
|
||||
Chroma.from_texts(
|
||||
client=client,
|
||||
collection_name=col.name,
|
||||
texts=docs,
|
||||
embedding=embedding_function,
|
||||
ids=[str(uuid.uuid4()) for _ in range(len(docs))],
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.requires("chromadb")
|
||||
@pytest.mark.skipif(
|
||||
not is_api_accessible("http://localhost:8000/api/v1/heartbeat"),
|
||||
reason="API not accessible",
|
||||
)
|
||||
@pytest.mark.skipif(
|
||||
not batch_support_chroma_version(),
|
||||
reason="ChromaDB version does not support batching",
|
||||
)
|
||||
def test_chroma_large_batch_update() -> None:
|
||||
import chromadb
|
||||
|
||||
client = chromadb.HttpClient()
|
||||
embedding_function = Fak(size=255)
|
||||
col = client.get_or_create_collection(
|
||||
"my_collection",
|
||||
embedding_function=embedding_function.embed_documents, # type: ignore
|
||||
)
|
||||
docs = ["This is a test document"] * (client.max_batch_size + 100)
|
||||
ids = [str(uuid.uuid4()) for _ in range(len(docs))]
|
||||
db = Chroma.from_texts(
|
||||
client=client,
|
||||
collection_name=col.name,
|
||||
texts=docs,
|
||||
embedding=embedding_function,
|
||||
ids=ids,
|
||||
)
|
||||
new_docs = [
|
||||
Document(
|
||||
page_content="This is a new test document", metadata={"doc_id": f"{i}"}
|
||||
)
|
||||
for i in range(len(docs) - 10)
|
||||
]
|
||||
new_ids = [_id for _id in ids[: len(new_docs)]]
|
||||
db.update_documents(ids=new_ids, documents=new_docs)
|
||||
|
||||
|
||||
@pytest.mark.requires("chromadb")
|
||||
@pytest.mark.skipif(
|
||||
not is_api_accessible("http://localhost:8000/api/v1/heartbeat"),
|
||||
reason="API not accessible",
|
||||
)
|
||||
@pytest.mark.skipif(
|
||||
batch_support_chroma_version(), reason="ChromaDB version does not support batching"
|
||||
)
|
||||
def test_chroma_legacy_batching() -> None:
|
||||
import chromadb
|
||||
|
||||
client = chromadb.HttpClient()
|
||||
embedding_function = Fak(size=255)
|
||||
col = client.get_or_create_collection(
|
||||
"my_collection",
|
||||
embedding_function=embedding_function.embed_documents, # type: ignore
|
||||
)
|
||||
docs = ["This is a test document"] * 100
|
||||
Chroma.from_texts(
|
||||
client=client,
|
||||
collection_name=col.name,
|
||||
texts=docs,
|
||||
embedding=embedding_function,
|
||||
ids=[str(uuid.uuid4()) for _ in range(len(docs))],
|
||||
)
|
||||
|
Loading…
Reference in New Issue
Block a user