mirror of
https://github.com/hwchase17/langchain.git
synced 2025-06-24 15:43:54 +00:00
Chroma persistence (#1028)
This PR adds persistence to the Chroma vector store. Users can supply a `persist_directory` with any of the `Chroma` creation methods. If supplied, the store will be automatically persisted at that directory. If a user creates a new `Chroma` instance with the same persistence directory, it will get loaded up automatically. If they use `from_texts` or `from_documents` in this way, the documents will be loaded into the existing store. There is the chance of some funky behavior if the user passes a different embedding function from the one used to create the collection - we will make this easier in future updates. For now, we log a warning.
This commit is contained in:
parent
012a6dfb16
commit
d43d430d86
@ -28,11 +28,15 @@ class Chroma(VectorStore):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self, collection_name: str, embedding_function: Optional[Embeddings] = None
|
self,
|
||||||
|
collection_name: str,
|
||||||
|
embedding_function: Optional[Embeddings] = None,
|
||||||
|
persist_directory: Optional[str] = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Initialize with Chroma client."""
|
"""Initialize with Chroma client."""
|
||||||
try:
|
try:
|
||||||
import chromadb
|
import chromadb
|
||||||
|
import chromadb.config
|
||||||
except ImportError:
|
except ImportError:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"Could not import chromadb python package. "
|
"Could not import chromadb python package. "
|
||||||
@ -40,17 +44,23 @@ class Chroma(VectorStore):
|
|||||||
)
|
)
|
||||||
|
|
||||||
# TODO: Add support for custom client. For now this is in-memory only.
|
# TODO: Add support for custom client. For now this is in-memory only.
|
||||||
self._client = chromadb.Client()
|
self._client_settings = chromadb.config.Settings()
|
||||||
|
if persist_directory is not None:
|
||||||
|
self._client_settings = chromadb.config.Settings(
|
||||||
|
chroma_db_impl="duckdb+parquet", persist_directory=persist_directory
|
||||||
|
)
|
||||||
|
self._client = chromadb.Client(self._client_settings)
|
||||||
self._embedding_function = embedding_function
|
self._embedding_function = embedding_function
|
||||||
|
self._persist_directory = persist_directory
|
||||||
|
|
||||||
# Check if the collection exists, create it if not
|
# Check if the collection exists, create it if not
|
||||||
if collection_name in [col.name for col in self._client.list_collections()]:
|
if collection_name in [col.name for col in self._client.list_collections()]:
|
||||||
self._collection = self._client.get_collection(name=collection_name)
|
self._collection = self._client.get_collection(name=collection_name)
|
||||||
if embedding_function is not None:
|
# TODO: Persist the user's embedding function
|
||||||
logger.warning(
|
logger.warning(
|
||||||
f"Collection {collection_name} already exists,"
|
f"Collection {collection_name} already exists,"
|
||||||
" embedding function will not be updated."
|
" Do you have the right embedding function?"
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
self._collection = self._client.create_collection(
|
self._collection = self._client.create_collection(
|
||||||
name=collection_name,
|
name=collection_name,
|
||||||
@ -78,7 +88,12 @@ class Chroma(VectorStore):
|
|||||||
# TODO: Handle the case where the user doesn't provide ids on the Collection
|
# TODO: Handle the case where the user doesn't provide ids on the Collection
|
||||||
if ids is None:
|
if ids is None:
|
||||||
ids = [str(uuid.uuid1()) for _ in texts]
|
ids = [str(uuid.uuid1()) for _ in texts]
|
||||||
self._collection.add(metadatas=metadatas, documents=texts, ids=ids)
|
embeddings = None
|
||||||
|
if self._embedding_function is not None:
|
||||||
|
embeddings = self._embedding_function.embed_documents(list(texts))
|
||||||
|
self._collection.add(
|
||||||
|
metadatas=metadatas, embeddings=embeddings, documents=texts, ids=ids
|
||||||
|
)
|
||||||
return ids
|
return ids
|
||||||
|
|
||||||
def similarity_search(
|
def similarity_search(
|
||||||
@ -116,6 +131,23 @@ class Chroma(VectorStore):
|
|||||||
]
|
]
|
||||||
return docs
|
return docs
|
||||||
|
|
||||||
|
def delete_collection(self) -> None:
|
||||||
|
"""Delete the collection."""
|
||||||
|
self._client.delete_collection(self._collection.name)
|
||||||
|
|
||||||
|
def persist(self) -> None:
|
||||||
|
"""Persist the collection.
|
||||||
|
|
||||||
|
This can be used to explicitly persist the data to disk.
|
||||||
|
It will also be called automatically when the object is destroyed.
|
||||||
|
"""
|
||||||
|
if self._persist_directory is None:
|
||||||
|
raise ValueError(
|
||||||
|
"You must specify a persist_directory on"
|
||||||
|
"creation to persist the collection."
|
||||||
|
)
|
||||||
|
self._client.persist()
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_texts(
|
def from_texts(
|
||||||
cls,
|
cls,
|
||||||
@ -124,12 +156,17 @@ class Chroma(VectorStore):
|
|||||||
metadatas: Optional[List[dict]] = None,
|
metadatas: Optional[List[dict]] = None,
|
||||||
ids: Optional[List[str]] = None,
|
ids: Optional[List[str]] = None,
|
||||||
collection_name: str = "langchain",
|
collection_name: str = "langchain",
|
||||||
|
persist_directory: Optional[str] = None,
|
||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
) -> Chroma:
|
) -> Chroma:
|
||||||
"""Create a Chroma vectorstore from a raw documents.
|
"""Create a Chroma vectorstore from a raw documents.
|
||||||
|
|
||||||
|
If a persist_directory is specified, the collection will be persisted there.
|
||||||
|
Otherwise, the data will be ephemeral in-memory.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
collection_name (str): Name of the collection to create.
|
collection_name (str): Name of the collection to create.
|
||||||
|
persist_directory (Optional[str]): Directory to persist the collection.
|
||||||
documents (List[Document]): List of documents to add.
|
documents (List[Document]): List of documents to add.
|
||||||
embedding (Optional[Embeddings]): Embedding function. Defaults to None.
|
embedding (Optional[Embeddings]): Embedding function. Defaults to None.
|
||||||
metadatas (Optional[List[dict]]): List of metadatas. Defaults to None.
|
metadatas (Optional[List[dict]]): List of metadatas. Defaults to None.
|
||||||
@ -139,7 +176,9 @@ class Chroma(VectorStore):
|
|||||||
Chroma: Chroma vectorstore.
|
Chroma: Chroma vectorstore.
|
||||||
"""
|
"""
|
||||||
chroma_collection = cls(
|
chroma_collection = cls(
|
||||||
collection_name=collection_name, embedding_function=embedding
|
collection_name=collection_name,
|
||||||
|
embedding_function=embedding,
|
||||||
|
persist_directory=persist_directory,
|
||||||
)
|
)
|
||||||
chroma_collection.add_texts(texts=texts, metadatas=metadatas, ids=ids)
|
chroma_collection.add_texts(texts=texts, metadatas=metadatas, ids=ids)
|
||||||
return chroma_collection
|
return chroma_collection
|
||||||
@ -151,12 +190,17 @@ class Chroma(VectorStore):
|
|||||||
embedding: Optional[Embeddings] = None,
|
embedding: Optional[Embeddings] = None,
|
||||||
ids: Optional[List[str]] = None,
|
ids: Optional[List[str]] = None,
|
||||||
collection_name: str = "langchain",
|
collection_name: str = "langchain",
|
||||||
|
persist_directory: Optional[str] = None,
|
||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
) -> Chroma:
|
) -> Chroma:
|
||||||
"""Create a Chroma vectorstore from a list of documents.
|
"""Create a Chroma vectorstore from a list of documents.
|
||||||
|
|
||||||
|
If a persist_directory is specified, the collection will be persisted there.
|
||||||
|
Otherwise, the data will be ephemeral in-memory.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
collection_name (str): Name of the collection to create.
|
collection_name (str): Name of the collection to create.
|
||||||
|
persist_directory (Optional[str]): Directory to persist the collection.
|
||||||
documents (List[Document]): List of documents to add to the vectorstore.
|
documents (List[Document]): List of documents to add to the vectorstore.
|
||||||
embedding (Optional[Embeddings]): Embedding function. Defaults to None.
|
embedding (Optional[Embeddings]): Embedding function. Defaults to None.
|
||||||
|
|
||||||
@ -166,9 +210,10 @@ class Chroma(VectorStore):
|
|||||||
texts = [doc.page_content for doc in documents]
|
texts = [doc.page_content for doc in documents]
|
||||||
metadatas = [doc.metadata for doc in documents]
|
metadatas = [doc.metadata for doc in documents]
|
||||||
return cls.from_texts(
|
return cls.from_texts(
|
||||||
collection_name=collection_name,
|
|
||||||
texts=texts,
|
texts=texts,
|
||||||
embedding=embedding,
|
embedding=embedding,
|
||||||
metadatas=metadatas,
|
metadatas=metadatas,
|
||||||
ids=ids,
|
ids=ids,
|
||||||
|
collection_name=collection_name,
|
||||||
|
persist_directory=persist_directory,
|
||||||
)
|
)
|
||||||
|
@ -26,3 +26,36 @@ def test_chroma_with_metadatas() -> None:
|
|||||||
)
|
)
|
||||||
output = docsearch.similarity_search("foo", k=1)
|
output = docsearch.similarity_search("foo", k=1)
|
||||||
assert output == [Document(page_content="foo", metadata={"page": "0"})]
|
assert output == [Document(page_content="foo", metadata={"page": "0"})]
|
||||||
|
|
||||||
|
|
||||||
|
def test_chroma_with_persistence() -> None:
|
||||||
|
"""Test end to end construction and search, with persistence."""
|
||||||
|
chroma_persist_dir = "./tests/persist_dir"
|
||||||
|
collection_name = "test_collection"
|
||||||
|
texts = ["foo", "bar", "baz"]
|
||||||
|
docsearch = Chroma.from_texts(
|
||||||
|
collection_name=collection_name,
|
||||||
|
texts=texts,
|
||||||
|
embedding=FakeEmbeddings(),
|
||||||
|
persist_directory=chroma_persist_dir,
|
||||||
|
)
|
||||||
|
|
||||||
|
output = docsearch.similarity_search("foo", k=1)
|
||||||
|
assert output == [Document(page_content="foo")]
|
||||||
|
|
||||||
|
docsearch.persist()
|
||||||
|
|
||||||
|
# Get a new VectorStore from the persisted directory
|
||||||
|
docsearch = Chroma(
|
||||||
|
collection_name=collection_name,
|
||||||
|
embedding_function=FakeEmbeddings(),
|
||||||
|
persist_directory=chroma_persist_dir,
|
||||||
|
)
|
||||||
|
output = docsearch.similarity_search("foo", k=1)
|
||||||
|
|
||||||
|
# Clean up
|
||||||
|
docsearch.delete_collection()
|
||||||
|
|
||||||
|
# Persist doesn't need to be called again
|
||||||
|
# Data will be automatically persisted on object deletion
|
||||||
|
# Or on program exit
|
||||||
|
Loading…
Reference in New Issue
Block a user