community: Chroma Adding create_collection_if_not_exists flag to Chroma constructor (#21420)

- **Description:** Adds the ability to either `get_or_create` or simply
`get_collection`. This is useful when dealing with read-only Chroma
instances where users are constrained to using `get_collection`. Targeted
at Http/CloudClients mostly.
- **Issue:** chroma-core/chroma#2163
- **Dependencies:** N/A
- **Twitter handle:** `@t_azarov`




| Collection Exists | create_collection_if_not_exists | Outcome | Test |
|-------------------|---------------------------------|---------|------|
| True | False | No errors, collection state unchanged | `test_create_collection_if_not_exist_false_existing` |
| True | True | No errors, collection state unchanged | `test_create_collection_if_not_exist_true_existing` |
| False | False | Error, `get_collection()` fails | `test_create_collection_if_not_exist_false_non_existing` |
| False | True | No errors, `get_or_create_collection()` creates the collection | `test_create_collection_if_not_exist_true_non_existing` |
This commit is contained in:
Trayan Azarov 2024-05-09 18:45:10 +03:00 committed by GitHub
parent 3bb9bec314
commit ba7d53689c
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 92 additions and 11 deletions

View File

@ -60,7 +60,8 @@ def cosine_similarity(X: Matrix, Y: Matrix) -> np.ndarray:
Y = np.array(Y) Y = np.array(Y)
if X.shape[1] != Y.shape[1]: if X.shape[1] != Y.shape[1]:
raise ValueError( raise ValueError(
f"Number of columns in X and Y must be the same. X has shape {X.shape} " "Number of columns in X and Y must be the same. X has shape"
f"{X.shape} "
f"and Y has shape {Y.shape}." f"and Y has shape {Y.shape}."
) )
@ -133,6 +134,7 @@ class Chroma(VectorStore):
collection_metadata: Optional[Dict] = None, collection_metadata: Optional[Dict] = None,
client: Optional[chromadb.ClientAPI] = None, client: Optional[chromadb.ClientAPI] = None,
relevance_score_fn: Optional[Callable[[float], float]] = None, relevance_score_fn: Optional[Callable[[float], float]] = None,
create_collection_if_not_exists: Optional[bool] = True,
) -> None: ) -> None:
"""Initialize with a Chroma client.""" """Initialize with a Chroma client."""
@ -161,11 +163,14 @@ class Chroma(VectorStore):
) )
self._embedding_function = embedding_function self._embedding_function = embedding_function
if create_collection_if_not_exists:
self._collection = self._client.get_or_create_collection( self._collection = self._client.get_or_create_collection(
name=collection_name, name=collection_name,
embedding_function=None, embedding_function=None,
metadata=collection_metadata, metadata=collection_metadata,
) )
else:
self._collection = self._client.get_collection(name=collection_name)
self.override_relevance_score_fn = relevance_score_fn self.override_relevance_score_fn = relevance_score_fn
@property @property
@ -650,7 +655,8 @@ class Chroma(VectorStore):
""" """
return self.update_documents([document_id], [document]) return self.update_documents([document_id], [document])
def update_documents(self, ids: List[str], documents: List[Document]) -> None: # type: ignore # type: ignore
def update_documents(self, ids: List[str], documents: List[Document]) -> None:
"""Update a document in the collection. """Update a document in the collection.
Args: Args:

View File

@ -1,10 +1,12 @@
"""Test Chroma functionality.""" """Test Chroma functionality."""
import uuid import uuid
from typing import Generator
import chromadb import chromadb
import pytest import pytest
import requests import requests
from chromadb.api.client import SharedSystemClient
from langchain_core.documents import Document from langchain_core.documents import Document
from langchain_core.embeddings.fake import FakeEmbeddings as Fak from langchain_core.embeddings.fake import FakeEmbeddings as Fak
@ -15,6 +17,13 @@ from tests.integration_tests.fake_embeddings import (
) )
@pytest.fixture()
def client() -> Generator[chromadb.ClientAPI, None, None]:
SharedSystemClient.clear_system_cache()
client = chromadb.Client(chromadb.config.Settings())
yield client
def test_chroma() -> None: def test_chroma() -> None:
"""Test end to end construction and search.""" """Test end to end construction and search."""
texts = ["foo", "bar", "baz"] texts = ["foo", "bar", "baz"]
@ -271,10 +280,7 @@ def test_chroma_with_relevance_score_custom_normalization_fn() -> None:
] ]
def test_init_from_client() -> None: def test_init_from_client(client: chromadb.ClientAPI) -> None:
import chromadb
client = chromadb.Client(chromadb.config.Settings())
Chroma(client=client) Chroma(client=client)
@ -414,3 +420,72 @@ def test_chroma_legacy_batching() -> None:
) )
db.delete_collection() db.delete_collection()
def test_create_collection_if_not_exist_default() -> None:
"""Tests existing behaviour without the new create_collection_if_not_exists flag."""
texts = ["foo", "bar", "baz"]
docsearch = Chroma.from_texts(
collection_name="test_collection", texts=texts, embedding=FakeEmbeddings()
)
assert docsearch._client.get_collection("test_collection") is not None
docsearch.delete_collection()
def test_create_collection_if_not_exist_true_existing(
client: chromadb.ClientAPI,
) -> None:
"""Tests create_collection_if_not_exists=True and collection already existing."""
client.create_collection("test_collection")
vectorstore = Chroma(
client=client,
collection_name="test_collection",
embedding_function=FakeEmbeddings(),
create_collection_if_not_exists=True,
)
assert vectorstore._client.get_collection("test_collection") is not None
vectorstore.delete_collection()
def test_create_collection_if_not_exist_false_existing(
client: chromadb.ClientAPI,
) -> None:
"""Tests create_collection_if_not_exists=False and collection already existing."""
client.create_collection("test_collection")
vectorstore = Chroma(
client=client,
collection_name="test_collection",
embedding_function=FakeEmbeddings(),
create_collection_if_not_exists=False,
)
assert vectorstore._client.get_collection("test_collection") is not None
vectorstore.delete_collection()
def test_create_collection_if_not_exist_false_non_existing(
client: chromadb.ClientAPI,
) -> None:
"""Tests create_collection_if_not_exists=False and collection not-existing,
should raise."""
with pytest.raises(Exception, match="does not exist"):
Chroma(
client=client,
collection_name="test_collection",
embedding_function=FakeEmbeddings(),
create_collection_if_not_exists=False,
)
def test_create_collection_if_not_exist_true_non_existing(
client: chromadb.ClientAPI,
) -> None:
"""Tests create_collection_if_not_exists=True and collection non-existing. ."""
vectorstore = Chroma(
client=client,
collection_name="test_collection",
embedding_function=FakeEmbeddings(),
create_collection_if_not_exists=True,
)
assert vectorstore._client.get_collection("test_collection") is not None
vectorstore.delete_collection()