Compare commits

...

14 Commits

Author SHA1 Message Date
Erick Friis
231e3f5235 revert pyproject.toml 2023-11-10 15:26:41 -08:00
Erick Friis
a76991c299 conditional import 2023-11-10 15:26:11 -08:00
Erick Friis
f5d31a78d6 test deps 2023-11-10 15:12:16 -08:00
Erick Friis
2f63afd084 Merge branch 'master' into refactorChromaInitLogic 2023-11-10 14:52:59 -08:00
Erick Friis
8d750dac1b lint 2023-11-10 14:38:14 -08:00
Erick Friis
a8a7a5869d format 2023-11-10 14:33:35 -08:00
Erick Friis
e87387c852 Merge branch 'master' into refactorChromaInitLogic 2023-11-10 14:33:21 -08:00
Jeffrey Huber
3d754a8060 cleanup 2023-08-17 09:54:12 -07:00
Jeffrey Huber
43324280d4 update 2023-08-17 09:49:11 -07:00
Bagatur
efdc11ea20 cr 2023-08-07 14:57:31 -07:00
Bagatur
8b725ec76d lint 2023-08-07 14:44:00 -07:00
Jeffrey Huber
cdccd11217 make black happy 2023-08-03 21:09:36 -07:00
Jeffrey Huber
ef862f817c add a few other comments 2023-08-03 20:36:56 -07:00
Jeffrey Huber
bda6e7945d refactor logic for chroma init 2023-08-03 20:30:43 -07:00
2 changed files with 172 additions and 58 deletions

View File

@@ -49,6 +49,37 @@ def _results_to_docs_and_scores(results: Any) -> List[Tuple[Document, float]]:
] ]
def _check_client_params(
client: Optional[chromadb.Client],
client_settings: Optional[chromadb.config.Settings],
persist_directory: Optional[str],
) -> None:
"""At most one of client, client_settings, persist_directory should be specified."""
if sum(x is not None for x in (client, client_settings, persist_directory)) > 1:
raise ValueError(
"""\
Only one of client, client_settings, or persist_directory should be set.
Chroma >= v0.4.0 syntax:
client: chromadb.Client,
eg: chromadb.PersistentClient(path="my_dir")
client_settings: chromadb.config.Settings, eg:
Settings(is_persistent=True,
persist_directory="my_dir")
persist_directory: str, eg: "my_dir"
Chroma <= v0.3.29 syntax:
client: chromadb.Client,
eg: chromadb.Client(
Settings(chroma_db_impl="duckdb+parquet", persist_directory="my_dir")
)
client_settings: chromadb.config.Settings,
eg: Settings(chroma_db_impl="duckdb+parquet", persist_directory="my_dir")
persist_directory: str, eg: "my_dir"
"""
)
class Chroma(VectorStore): class Chroma(VectorStore):
"""`ChromaDB` vector store. """`ChromaDB` vector store.
@@ -65,6 +96,7 @@ class Chroma(VectorStore):
""" """
_LANGCHAIN_DEFAULT_COLLECTION_NAME = "langchain" _LANGCHAIN_DEFAULT_COLLECTION_NAME = "langchain"
_client: chromadb.Client = None
def __init__( def __init__(
self, self,
@@ -86,41 +118,48 @@ class Chroma(VectorStore):
"Please install it with `pip install chromadb`." "Please install it with `pip install chromadb`."
) )
if client is not None: _check_client_params(client, client_settings, persist_directory)
self._client_settings = client_settings # default settings
self._client = client _client_settings = chromadb.config.Settings()
self._persist_directory = persist_directory self._client = chromadb.EphemeralClient(settings=_client_settings)
else:
if client_settings:
# If client_settings is provided with persist_directory specified,
# then it is "in-memory and persisting to disk" mode.
client_settings.persist_directory = (
persist_directory or client_settings.persist_directory
)
if client_settings.persist_directory is not None:
# Maintain backwards compatibility with chromadb < 0.4.0
major, minor, _ = chromadb.__version__.split(".")
if int(major) == 0 and int(minor) < 4:
client_settings.chroma_db_impl = "duckdb+parquet"
_client_settings = client_settings # if user provides client, use it
elif persist_directory: if client:
# Maintain backwards compatibility with chromadb < 0.4.0 self._client = client
major, minor, _ = chromadb.__version__.split(".")
if int(major) == 0 and int(minor) < 4: # get chromadb version
_client_settings = chromadb.config.Settings( major_version, minor_version, _ = chromadb.__version__.split(".")
chroma_db_impl="duckdb+parquet", pre_04_chroma = int(major_version) == 0 and int(minor_version) < 4
) if pre_04_chroma:
else: logger.info("Using ChromaDB <0.4.0. Consider upgrading >= 0.4.0")
_client_settings = chromadb.config.Settings(is_persistent=True)
_client_settings.persist_directory = persist_directory # if user provides persist_directory, use it
if persist_directory:
if pre_04_chroma:
_client_settings = chromadb.config.Settings(
chroma_db_impl="duckdb+parquet",
)
self._persist_directory = persist_directory
else: else:
_client_settings = chromadb.config.Settings() _client_settings = chromadb.config.Settings(is_persistent=True)
self._client_settings = _client_settings # We don't set self._persist_directory because
# .persist() is automatic in ChromaDB >= 0.4.0
_client_settings.persist_directory = persist_directory
self._client = chromadb.Client(_client_settings) self._client = chromadb.Client(_client_settings)
self._persist_directory = (
_client_settings.persist_directory or persist_directory # if user provides client_settings, use it - this should be a rare case
) if client_settings:
if (
pre_04_chroma
and client_settings.persist_directory is not None
and client_settings.chroma_db_impl is None
):
raise Exception(
"When using persist_directory with ChromaDB < 0.4.0, set "
"chroma_db_impl='duckdb+parquet' in client_settings"
)
self._client = chromadb.Client(client_settings)
self._embedding_function = embedding_function self._embedding_function = embedding_function
self._collection = self._client.get_or_create_collection( self._collection = self._client.get_or_create_collection(
@@ -148,7 +187,7 @@ class Chroma(VectorStore):
try: try:
import chromadb # noqa: F401 import chromadb # noqa: F401
except ImportError: except ImportError:
raise ValueError( raise ImportError(
"Could not import chromadb python package. " "Could not import chromadb python package. "
"Please install it with `pip install chromadb`." "Please install it with `pip install chromadb`."
) )
@@ -604,21 +643,27 @@ class Chroma(VectorStore):
return self._collection.get(**kwargs) return self._collection.get(**kwargs)
def persist(self) -> None:
    """Persist the collection to disk. Obsolete in ChromaDB >= 0.4.0.

    On ChromaDB < 0.4.0 this explicitly asks the client to persist its data.
    On ChromaDB 0.4.0 and above the call only logs a notice and returns.

    Raises:
        ValueError: On ChromaDB < 0.4.0 when no persist directory was
            recorded on this instance at creation.
    """
    import chromadb

    # Compare only (major, minor). Slicing tolerates version strings that do
    # not have exactly three dot-separated components.
    version = tuple(int(part) for part in chromadb.__version__.split(".")[:2])

    # ">= (0, 4)" covers 0.4.x itself as well as any later minor/major release.
    if version >= (0, 4):
        logger.info(
            "Chroma 0.4.0 and above automatically persist the collection. You can "
            "remove this call."
        )
        return

    # Maintain backwards compatibility with chromadb < 0.4.0.
    # getattr: the attribute may never have been assigned on this instance,
    # which should surface as the documented ValueError, not AttributeError.
    if getattr(self, "_persist_directory", None) is None:
        raise ValueError(
            "You must specify a persist_directory on "
            "creation to persist the collection."
        )
    self._client.persist()
def update_document(self, document_id: str, document: Document) -> None: def update_document(self, document_id: str, document: Document) -> None:

View File

@@ -12,6 +12,15 @@ from tests.integration_tests.vectorstores.fake_embeddings import (
FakeEmbeddings, FakeEmbeddings,
) )
# reset the collection before each test
@pytest.fixture(autouse=True)
def reset_collection() -> None:
    """Delete the collections used by this module before each test."""
    for collection_name in ("test_collection", "langchain"):
        Chroma(collection_name=collection_name).delete_collection()
def test_chroma() -> None: def test_chroma() -> None:
"""Test end to end construction and search.""" """Test end to end construction and search."""
@@ -136,7 +145,7 @@ def test_chroma_with_persistence() -> None:
output = docsearch.similarity_search("foo", k=1) output = docsearch.similarity_search("foo", k=1)
assert output == [Document(page_content="foo")] assert output == [Document(page_content="foo")]
docsearch.persist() # docsearch.persist() # no longer needed in chromadb 0.4+
# Get a new VectorStore from the persisted directory # Get a new VectorStore from the persisted directory
docsearch = Chroma( docsearch = Chroma(
@@ -234,23 +243,27 @@ def test_chroma_update_document() -> None:
assert new_embedding != old_embedding assert new_embedding != old_embedding
def test_chroma_with_relevance_score() -> None: # _euclidean_relevance_score_fn is not correct math
"""Test to make sure the relevance score is scaled to 0-1.""" # our l2 does not have a sqrt because hnsw does not (it's faster)
texts = ["foo", "bar", "baz"] # https://github.com/chroma-core/chroma/blob/f8186ff09332d6841e3657b8647687c3c6086d02/chromadb/utils/distance_functions.py#L8
metadatas = [{"page": str(i)} for i in range(len(texts))]
docsearch = Chroma.from_texts( # def test_chroma_with_relevance_score() -> None:
collection_name="test_collection", # """Test to make sure the relevance score is scaled to 0-1."""
texts=texts, # texts = ["foo", "bar", "baz"]
embedding=FakeEmbeddings(), # metadatas = [{"page": str(i)} for i in range(len(texts))]
metadatas=metadatas, # docsearch = Chroma.from_texts(
collection_metadata={"hnsw:space": "l2"}, # collection_name="test_collection",
) # texts=texts,
output = docsearch.similarity_search_with_relevance_scores("foo", k=3) # embedding=ConsistentFakeEmbeddings(),
assert output == [ # metadatas=metadatas,
(Document(page_content="foo", metadata={"page": "0"}), 1.0), # collection_metadata={"hnsw:space": "l2"},
(Document(page_content="bar", metadata={"page": "1"}), 0.8), # )
(Document(page_content="baz", metadata={"page": "2"}), 0.5), # output = docsearch.similarity_search_with_relevance_scores("foo", k=3)
] # assert output == [
# (Document(page_content="foo", metadata={"page": "0"}), 1.0),
# (Document(page_content="bar", metadata={"page": "1"}), 0.8),
# (Document(page_content="baz", metadata={"page": "2"}), 0.5),
# ]
def test_chroma_with_relevance_score_custom_normalization_fn() -> None: def test_chroma_with_relevance_score_custom_normalization_fn() -> None:
@@ -307,6 +320,62 @@ def test_chroma_add_documents_mixed_metadata() -> None:
) )
def test_multiple_inputs_fails() -> None:
    """Passing both persist_directory and client_settings raises ValueError."""
    conflicting_kwargs = {"persist_directory": "foo", "client_settings": {}}
    with pytest.raises(ValueError):
        Chroma(embedding_function=FakeEmbeddings(), **conflicting_kwargs)
def test_no_config_uses_ephemeral() -> None:
    """Without client configuration the client settings are non-persistent."""
    store = Chroma(embedding_function=FakeEmbeddings())
    settings = store._collection._client.get_settings()
    assert settings.is_persistent is False
def test_persist_dir_uses_persist() -> None:
    """With persist_directory given, the client settings are persistent."""
    store = Chroma(embedding_function=FakeEmbeddings(), persist_directory="foo")
    settings = store._collection._client.get_settings()
    assert settings.is_persistent is True
def test_client_settings_uses_persist() -> None:
    """client_settings with is_persistent=True yields persistent settings."""
    import chromadb

    persistent_settings = chromadb.config.Settings(is_persistent=True)
    store = Chroma(
        embedding_function=FakeEmbeddings(),
        client_settings=persistent_settings,
    )
    effective = store._collection._client.get_settings()
    assert effective.is_persistent is True
def test_client_settings_uses_no_persist() -> None:
    """Default client_settings yield non-persistent settings."""
    import chromadb

    default_settings = chromadb.config.Settings()
    store = Chroma(
        embedding_function=FakeEmbeddings(),
        client_settings=default_settings,
    )
    effective = store._collection._client.get_settings()
    assert effective.is_persistent is False
def test_respects_client_ephemeral() -> None:
    """A supplied EphemeralClient is the client the collection ends up using."""
    import chromadb

    chroma_client = chromadb.EphemeralClient()
    store = Chroma(embedding_function=FakeEmbeddings(), client=chroma_client)

    assert store._collection._client == chroma_client
    assert store._collection._client.get_settings().is_persistent is False
def test_respects_client_persistent() -> None:
    """A supplied PersistentClient is the client the collection ends up using."""
    import chromadb

    chroma_client = chromadb.PersistentClient()
    store = Chroma(embedding_function=FakeEmbeddings(), client=chroma_client)

    assert store._collection._client == chroma_client
    assert store._collection._client.get_settings().is_persistent is True
def is_api_accessible(url: str) -> bool: def is_api_accessible(url: str) -> bool:
try: try:
response = requests.get(url) response = requests.get(url)