Chroma persistence (#1028)

This PR adds persistence to the Chroma vector store.

Users can supply a `persist_directory` with any of the `Chroma` creation
methods. If supplied, the store will be automatically persisted at that
directory.

If a user creates a new `Chroma` instance with the same persistence
directory, the previously persisted store is loaded automatically. If they
use `from_texts` or `from_documents` in this way, the new documents are
added to the existing store.

Search results may be unreliable if the user passes a different embedding
function from the one used to create the collection. We will make this
easier to manage in future updates; for now, we log a warning.
This commit is contained in:
Anton Troynikov 2023-02-13 21:09:06 -08:00 committed by GitHub
parent 012a6dfb16
commit d43d430d86
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 88 additions and 10 deletions

View File

@ -28,11 +28,15 @@ class Chroma(VectorStore):
""" """
def __init__( def __init__(
self, collection_name: str, embedding_function: Optional[Embeddings] = None self,
collection_name: str,
embedding_function: Optional[Embeddings] = None,
persist_directory: Optional[str] = None,
) -> None: ) -> None:
"""Initialize with Chroma client.""" """Initialize with Chroma client."""
try: try:
import chromadb import chromadb
import chromadb.config
except ImportError: except ImportError:
raise ValueError( raise ValueError(
"Could not import chromadb python package. " "Could not import chromadb python package. "
@ -40,17 +44,23 @@ class Chroma(VectorStore):
) )
# TODO: Add support for custom client. For now this is in-memory only. # TODO: Add support for custom client. For now this is in-memory only.
self._client = chromadb.Client() self._client_settings = chromadb.config.Settings()
if persist_directory is not None:
self._client_settings = chromadb.config.Settings(
chroma_db_impl="duckdb+parquet", persist_directory=persist_directory
)
self._client = chromadb.Client(self._client_settings)
self._embedding_function = embedding_function self._embedding_function = embedding_function
self._persist_directory = persist_directory
# Check if the collection exists, create it if not # Check if the collection exists, create it if not
if collection_name in [col.name for col in self._client.list_collections()]: if collection_name in [col.name for col in self._client.list_collections()]:
self._collection = self._client.get_collection(name=collection_name) self._collection = self._client.get_collection(name=collection_name)
if embedding_function is not None: # TODO: Persist the user's embedding function
logger.warning( logger.warning(
f"Collection {collection_name} already exists," f"Collection {collection_name} already exists,"
" embedding function will not be updated." " Do you have the right embedding function?"
) )
else: else:
self._collection = self._client.create_collection( self._collection = self._client.create_collection(
name=collection_name, name=collection_name,
@ -78,7 +88,12 @@ class Chroma(VectorStore):
# TODO: Handle the case where the user doesn't provide ids on the Collection # TODO: Handle the case where the user doesn't provide ids on the Collection
if ids is None: if ids is None:
ids = [str(uuid.uuid1()) for _ in texts] ids = [str(uuid.uuid1()) for _ in texts]
self._collection.add(metadatas=metadatas, documents=texts, ids=ids) embeddings = None
if self._embedding_function is not None:
embeddings = self._embedding_function.embed_documents(list(texts))
self._collection.add(
metadatas=metadatas, embeddings=embeddings, documents=texts, ids=ids
)
return ids return ids
def similarity_search( def similarity_search(
@ -116,6 +131,23 @@ class Chroma(VectorStore):
] ]
return docs return docs
def delete_collection(self) -> None:
    """Remove this store's collection from the underlying Chroma client.

    The collection to remove is identified by the name of
    ``self._collection``.
    """
    target_name = self._collection.name
    self._client.delete_collection(target_name)
def persist(self) -> None:
    """Persist the collection to disk.

    Explicitly asks the underlying Chroma client to persist its data. This
    is only possible when a ``persist_directory`` was supplied on creation.

    NOTE(review): the original docstring also said this is "called
    automatically when the object is destroyed"; no such hook is visible in
    this block, so that presumably relies on the chromadb client itself --
    confirm before depending on it.

    Raises:
        ValueError: If no ``persist_directory`` was specified on creation.
    """
    if self._persist_directory is None:
        # Adjacent string literals are concatenated verbatim, so the space
        # between "on" and "creation" has to be written out explicitly.
        raise ValueError(
            "You must specify a persist_directory on "
            "creation to persist the collection."
        )
    self._client.persist()
@classmethod @classmethod
def from_texts( def from_texts(
cls, cls,
@ -124,12 +156,17 @@ class Chroma(VectorStore):
metadatas: Optional[List[dict]] = None, metadatas: Optional[List[dict]] = None,
ids: Optional[List[str]] = None, ids: Optional[List[str]] = None,
collection_name: str = "langchain", collection_name: str = "langchain",
persist_directory: Optional[str] = None,
**kwargs: Any, **kwargs: Any,
) -> Chroma: ) -> Chroma:
"""Create a Chroma vectorstore from a raw documents. """Create a Chroma vectorstore from a raw documents.
If a persist_directory is specified, the collection will be persisted there.
Otherwise, the data will be ephemeral in-memory.
Args: Args:
collection_name (str): Name of the collection to create. collection_name (str): Name of the collection to create.
persist_directory (Optional[str]): Directory to persist the collection.
documents (List[Document]): List of documents to add. documents (List[Document]): List of documents to add.
embedding (Optional[Embeddings]): Embedding function. Defaults to None. embedding (Optional[Embeddings]): Embedding function. Defaults to None.
metadatas (Optional[List[dict]]): List of metadatas. Defaults to None. metadatas (Optional[List[dict]]): List of metadatas. Defaults to None.
@ -139,7 +176,9 @@ class Chroma(VectorStore):
Chroma: Chroma vectorstore. Chroma: Chroma vectorstore.
""" """
chroma_collection = cls( chroma_collection = cls(
collection_name=collection_name, embedding_function=embedding collection_name=collection_name,
embedding_function=embedding,
persist_directory=persist_directory,
) )
chroma_collection.add_texts(texts=texts, metadatas=metadatas, ids=ids) chroma_collection.add_texts(texts=texts, metadatas=metadatas, ids=ids)
return chroma_collection return chroma_collection
@ -151,12 +190,17 @@ class Chroma(VectorStore):
embedding: Optional[Embeddings] = None, embedding: Optional[Embeddings] = None,
ids: Optional[List[str]] = None, ids: Optional[List[str]] = None,
collection_name: str = "langchain", collection_name: str = "langchain",
persist_directory: Optional[str] = None,
**kwargs: Any, **kwargs: Any,
) -> Chroma: ) -> Chroma:
"""Create a Chroma vectorstore from a list of documents. """Create a Chroma vectorstore from a list of documents.
If a persist_directory is specified, the collection will be persisted there.
Otherwise, the data will be ephemeral in-memory.
Args: Args:
collection_name (str): Name of the collection to create. collection_name (str): Name of the collection to create.
persist_directory (Optional[str]): Directory to persist the collection.
documents (List[Document]): List of documents to add to the vectorstore. documents (List[Document]): List of documents to add to the vectorstore.
embedding (Optional[Embeddings]): Embedding function. Defaults to None. embedding (Optional[Embeddings]): Embedding function. Defaults to None.
@ -166,9 +210,10 @@ class Chroma(VectorStore):
texts = [doc.page_content for doc in documents] texts = [doc.page_content for doc in documents]
metadatas = [doc.metadata for doc in documents] metadatas = [doc.metadata for doc in documents]
return cls.from_texts( return cls.from_texts(
collection_name=collection_name,
texts=texts, texts=texts,
embedding=embedding, embedding=embedding,
metadatas=metadatas, metadatas=metadatas,
ids=ids, ids=ids,
collection_name=collection_name,
persist_directory=persist_directory,
) )

View File

@ -26,3 +26,36 @@ def test_chroma_with_metadatas() -> None:
) )
output = docsearch.similarity_search("foo", k=1) output = docsearch.similarity_search("foo", k=1)
assert output == [Document(page_content="foo", metadata={"page": "0"})] assert output == [Document(page_content="foo", metadata={"page": "0"})]
def test_chroma_with_persistence() -> None:
    """Test end to end construction and search, with persistence."""
    chroma_persist_dir = "./tests/persist_dir"
    collection_name = "test_collection"
    texts = ["foo", "bar", "baz"]
    docsearch = Chroma.from_texts(
        collection_name=collection_name,
        texts=texts,
        embedding=FakeEmbeddings(),
        persist_directory=chroma_persist_dir,
    )

    try:
        output = docsearch.similarity_search("foo", k=1)
        assert output == [Document(page_content="foo")]

        docsearch.persist()

        # Get a new VectorStore from the persisted directory
        docsearch = Chroma(
            collection_name=collection_name,
            embedding_function=FakeEmbeddings(),
            persist_directory=chroma_persist_dir,
        )
        output = docsearch.similarity_search("foo", k=1)
        # The reloaded store must return the same result; without this check
        # the test would pass even if nothing had been persisted.
        assert output == [Document(page_content="foo")]
    finally:
        # Clean up, even when an assertion above fails, so that a stale
        # collection is not left behind in the persist directory.
        docsearch.delete_collection()

    # Persist doesn't need to be called again
    # Data will be automatically persisted on object deletion
    # Or on program exit