diff --git a/libs/partners/chroma/langchain_chroma/vectorstores.py b/libs/partners/chroma/langchain_chroma/vectorstores.py
index 9efbff393aa..221820173be 100644
--- a/libs/partners/chroma/langchain_chroma/vectorstores.py
+++ b/libs/partners/chroma/langchain_chroma/vectorstores.py
@@ -163,16 +163,33 @@ class Chroma(VectorStore):
             )
 
         self._embedding_function = embedding_function
+        self._chroma_collection: Optional[chromadb.Collection] = None
+        self._collection_name = collection_name
+        self._collection_metadata = collection_metadata
         if create_collection_if_not_exists:
-            self._collection = self._client.get_or_create_collection(
-                name=collection_name,
-                embedding_function=None,
-                metadata=collection_metadata,
-            )
+            self.__ensure_collection()
         else:
-            self._collection = self._client.get_collection(name=collection_name)
+            self._chroma_collection = self._client.get_collection(name=collection_name)
         self.override_relevance_score_fn = relevance_score_fn
 
+    def __ensure_collection(self) -> None:
+        """Ensure that the collection exists or create it."""
+        self._chroma_collection = self._client.get_or_create_collection(
+            name=self._collection_name,
+            embedding_function=None,
+            metadata=self._collection_metadata,
+        )
+
+    @property
+    def _collection(self) -> chromadb.Collection:
+        """Returns the underlying Chroma collection or throws an exception."""
+        if self._chroma_collection is None:
+            raise ValueError(
+                "Chroma collection not initialized. "
+                "Use `reset_collection` to re-create and initialize the collection. "
+            )
+        return self._chroma_collection
+
     @property
     def embeddings(self) -> Optional[Embeddings]:
         return self._embedding_function
@@ -607,6 +624,17 @@ class Chroma(VectorStore):
     def delete_collection(self) -> None:
         """Delete the collection."""
         self._client.delete_collection(self._collection.name)
+        self._chroma_collection = None
+
+    def reset_collection(self) -> None:
+        """Resets the collection by deleting the collection
+        and recreating an empty one."""
+        # `delete_collection` dereferences `self._collection`, which raises
+        # once the handle has been cleared; skip it in that state so this
+        # method can recover an already-deleted collection.
+        if self._chroma_collection is not None:
+            self.delete_collection()
+        self.__ensure_collection()
 
     def get(
         self,
diff --git a/libs/partners/chroma/tests/integration_tests/test_vectorstores.py b/libs/partners/chroma/tests/integration_tests/test_vectorstores.py
index 97018f87337..06f09392161 100644
--- a/libs/partners/chroma/tests/integration_tests/test_vectorstores.py
+++ b/libs/partners/chroma/tests/integration_tests/test_vectorstores.py
@@ -489,3 +489,58 @@ def test_create_collection_if_not_exist_true_non_existing(
 
     assert vectorstore._client.get_collection("test_collection") is not None
     vectorstore.delete_collection()
+
+
+def test_collection_none_after_delete(
+    client: chromadb.ClientAPI,
+) -> None:
+    """Tests that the collection handle is cleared after delete_collection."""
+    vectorstore = Chroma(
+        client=client,
+        collection_name="test_collection",
+        embedding_function=FakeEmbeddings(),
+    )
+
+    assert vectorstore._client.get_collection("test_collection") is not None
+    vectorstore.delete_collection()
+    assert vectorstore._chroma_collection is None
+    with pytest.raises(Exception, match="Chroma collection not initialized"):
+        _ = vectorstore._collection
+    with pytest.raises(Exception, match="does not exist"):
+        vectorstore._client.get_collection("test_collection")
+    with pytest.raises(Exception):
+        vectorstore.similarity_search("foo")
+
+
+def test_reset_collection(client: chromadb.ClientAPI) -> None:
+    """Tests reset_collection method."""
+    vectorstore = Chroma(
+        client=client,
+        collection_name="test_collection",
+        embedding_function=FakeEmbeddings(),
+    )
+    vectorstore.add_documents([Document(page_content="foo")])
+    assert vectorstore._collection.count() == 1
+    vectorstore.reset_collection()
+    assert vectorstore._chroma_collection is not None
+    assert vectorstore._client.get_collection("test_collection") is not None
+    assert vectorstore._collection.name == "test_collection"
+    assert vectorstore._collection.count() == 0
+    # Clean up
+    vectorstore.delete_collection()
+
+
+def test_reset_collection_after_delete(client: chromadb.ClientAPI) -> None:
+    """Tests that reset_collection re-creates a collection that was deleted."""
+    vectorstore = Chroma(
+        client=client,
+        collection_name="test_collection",
+        embedding_function=FakeEmbeddings(),
+    )
+    vectorstore.delete_collection()
+    assert vectorstore._chroma_collection is None
+    vectorstore.reset_collection()
+    assert vectorstore._chroma_collection is not None
+    assert vectorstore._collection.count() == 0
+    # Clean up
+    vectorstore.delete_collection()