From ba7d53689c5fc6551ac02164ed7f5bab6d9a8c83 Mon Sep 17 00:00:00 2001 From: Trayan Azarov Date: Thu, 9 May 2024 18:45:10 +0300 Subject: [PATCH] community: Chroma Adding create_collection_if_not_exists flag to Chroma constructor (#21420) - **Description:** Adds the ability to either `get_or_create` or simply `get_collection`. This is useful when dealing with read-only Chroma instances where users are constraint to using `get_collection`. Targeted at Http/CloudClients mostly. - **Issue:** chroma-core/chroma#2163 - **Dependencies:** N/A - **Twitter handle:** `@t_azarov` | Collection Exists | create_collection_if_not_exists | Outcome | test | |-------------------|---------------------------------|----------------------------------------------------------------|----------------------------------------------------------| | True | False | No errors, collection state unchanged | `test_create_collection_if_not_exist_false_existing` | | True | True | No errors, collection state unchanged | `test_create_collection_if_not_exist_true_existing` | | False | False | Error, `get_collection()` fails | `test_create_collection_if_not_exist_false_non_existing` | | False | True | No errors, `get_or_create_collection()` creates the collection | `test_create_collection_if_not_exist_true_non_existing` | --- .../chroma/langchain_chroma/vectorstores.py | 20 +++-- .../integration_tests/test_vectorstores.py | 83 ++++++++++++++++++- 2 files changed, 92 insertions(+), 11 deletions(-) diff --git a/libs/partners/chroma/langchain_chroma/vectorstores.py b/libs/partners/chroma/langchain_chroma/vectorstores.py index 5e5394507f2..d5cb86d534d 100644 --- a/libs/partners/chroma/langchain_chroma/vectorstores.py +++ b/libs/partners/chroma/langchain_chroma/vectorstores.py @@ -60,7 +60,8 @@ def cosine_similarity(X: Matrix, Y: Matrix) -> np.ndarray: Y = np.array(Y) if X.shape[1] != Y.shape[1]: raise ValueError( - f"Number of columns in X and Y must be the same. X has shape {X.shape} " + "Number of columns in X and Y must be the same. X has shape" + f"{X.shape} " f"and Y has shape {Y.shape}." ) @@ -133,6 +134,7 @@ class Chroma(VectorStore): collection_metadata: Optional[Dict] = None, client: Optional[chromadb.ClientAPI] = None, relevance_score_fn: Optional[Callable[[float], float]] = None, + create_collection_if_not_exists: Optional[bool] = True, ) -> None: """Initialize with a Chroma client.""" @@ -161,11 +163,14 @@ class Chroma(VectorStore): ) self._embedding_function = embedding_function - self._collection = self._client.get_or_create_collection( - name=collection_name, - embedding_function=None, - metadata=collection_metadata, - ) + if create_collection_if_not_exists: + self._collection = self._client.get_or_create_collection( + name=collection_name, + embedding_function=None, + metadata=collection_metadata, + ) + else: + self._collection = self._client.get_collection(name=collection_name) self.override_relevance_score_fn = relevance_score_fn @property @@ -650,7 +655,8 @@ class Chroma(VectorStore): """ return self.update_documents([document_id], [document]) - def update_documents(self, ids: List[str], documents: List[Document]) -> None: # type: ignore + # type: ignore + def update_documents(self, ids: List[str], documents: List[Document]) -> None: """Update a document in the collection. Args: diff --git a/libs/partners/chroma/tests/integration_tests/test_vectorstores.py b/libs/partners/chroma/tests/integration_tests/test_vectorstores.py index 156a8256215..97018f87337 100644 --- a/libs/partners/chroma/tests/integration_tests/test_vectorstores.py +++ b/libs/partners/chroma/tests/integration_tests/test_vectorstores.py @@ -1,10 +1,12 @@ """Test Chroma functionality.""" import uuid +from typing import Generator import chromadb import pytest import requests +from chromadb.api.client import SharedSystemClient from langchain_core.documents import Document from langchain_core.embeddings.fake import FakeEmbeddings as Fak @@ -15,6 +17,13 @@ from tests.integration_tests.fake_embeddings import ( ) +@pytest.fixture() +def client() -> Generator[chromadb.ClientAPI, None, None]: + SharedSystemClient.clear_system_cache() + client = chromadb.Client(chromadb.config.Settings()) + yield client + + def test_chroma() -> None: """Test end to end construction and search.""" texts = ["foo", "bar", "baz"] @@ -271,10 +280,7 @@ def test_chroma_with_relevance_score_custom_normalization_fn() -> None: ] -def test_init_from_client() -> None: - import chromadb - - client = chromadb.Client(chromadb.config.Settings()) +def test_init_from_client(client: chromadb.ClientAPI) -> None: Chroma(client=client) @@ -414,3 +420,72 @@ def test_chroma_legacy_batching() -> None: ) db.delete_collection() + + +def test_create_collection_if_not_exist_default() -> None: + """Tests existing behaviour without the new create_collection_if_not_exists flag.""" + texts = ["foo", "bar", "baz"] + docsearch = Chroma.from_texts( + collection_name="test_collection", texts=texts, embedding=FakeEmbeddings() + ) + assert docsearch._client.get_collection("test_collection") is not None + docsearch.delete_collection() + + +def test_create_collection_if_not_exist_true_existing( + client: chromadb.ClientAPI, +) -> None: + """Tests create_collection_if_not_exists=True and collection already existing.""" + client.create_collection("test_collection") + vectorstore = Chroma( + client=client, + collection_name="test_collection", + embedding_function=FakeEmbeddings(), + create_collection_if_not_exists=True, + ) + assert vectorstore._client.get_collection("test_collection") is not None + vectorstore.delete_collection() + + +def test_create_collection_if_not_exist_false_existing( + client: chromadb.ClientAPI, +) -> None: + """Tests create_collection_if_not_exists=False and collection already existing.""" + client.create_collection("test_collection") + vectorstore = Chroma( + client=client, + collection_name="test_collection", + embedding_function=FakeEmbeddings(), + create_collection_if_not_exists=False, + ) + assert vectorstore._client.get_collection("test_collection") is not None + vectorstore.delete_collection() + + +def test_create_collection_if_not_exist_false_non_existing( + client: chromadb.ClientAPI, +) -> None: + """Tests create_collection_if_not_exists=False and collection not-existing, + should raise.""" + with pytest.raises(Exception, match="does not exist"): + Chroma( + client=client, + collection_name="test_collection", + embedding_function=FakeEmbeddings(), + create_collection_if_not_exists=False, + ) + + +def test_create_collection_if_not_exist_true_non_existing( + client: chromadb.ClientAPI, +) -> None: + """Tests create_collection_if_not_exists=True and collection non-existing. .""" + vectorstore = Chroma( + client=client, + collection_name="test_collection", + embedding_function=FakeEmbeddings(), + create_collection_if_not_exists=True, + ) + + assert vectorstore._client.get_collection("test_collection") is not None + vectorstore.delete_collection()