Compare commits

...

14 Commits

Author SHA1 Message Date
Erick Friis
231e3f5235 revert pyproject.toml 2023-11-10 15:26:41 -08:00
Erick Friis
a76991c299 conditional import 2023-11-10 15:26:11 -08:00
Erick Friis
f5d31a78d6 test deps 2023-11-10 15:12:16 -08:00
Erick Friis
2f63afd084 Merge branch 'master' into refactorChromaInitLogic 2023-11-10 14:52:59 -08:00
Erick Friis
8d750dac1b lint 2023-11-10 14:38:14 -08:00
Erick Friis
a8a7a5869d format 2023-11-10 14:33:35 -08:00
Erick Friis
e87387c852 Merge branch 'master' into refactorChromaInitLogic 2023-11-10 14:33:21 -08:00
Jeffrey Huber
3d754a8060 cleanup 2023-08-17 09:54:12 -07:00
Jeffrey Huber
43324280d4 update 2023-08-17 09:49:11 -07:00
Bagatur
efdc11ea20 cr 2023-08-07 14:57:31 -07:00
Bagatur
8b725ec76d lint 2023-08-07 14:44:00 -07:00
Jeffrey Huber
cdccd11217 make black happy 2023-08-03 21:09:36 -07:00
Jeffrey Huber
ef862f817c add a few other comments 2023-08-03 20:36:56 -07:00
Jeffrey Huber
bda6e7945d refactor logic for chroma init 2023-08-03 20:30:43 -07:00
2 changed files with 172 additions and 58 deletions

View File

@@ -49,6 +49,37 @@ def _results_to_docs_and_scores(results: Any) -> List[Tuple[Document, float]]:
]
def _check_client_params(
    client: Optional[chromadb.Client],
    client_settings: Optional[chromadb.config.Settings],
    persist_directory: Optional[str],
) -> None:
    """Reject ambiguous client configuration.

    Args:
        client: An already-constructed chromadb client, or None.
        client_settings: chromadb settings to build a client from, or None.
        persist_directory: Directory to persist to, or None.

    Raises:
        ValueError: If more than one of the three arguments is not None.
    """
    supplied = [
        option
        for option in (client, client_settings, persist_directory)
        if option is not None
    ]
    # Zero or one configured source is unambiguous; nothing to complain about.
    if len(supplied) <= 1:
        return
    raise ValueError(
        """\
Only one of client, client_settings, or persist_directory should be set.
Chroma >= v0.4.0 syntax:
client: chromadb.Client,
eg: chromadb.PersistentClient(path="my_dir")
client_settings: chromadb.config.Settings, eg:
Settings(is_persistent=True,
persist_directory="my_dir")
persist_directory: str, eg: "my_dir"
Chroma <= v0.3.29 syntax:
client: chromadb.Client,
eg: chromadb.Client(
Settings(chroma_db_impl="duckdb+parquet", persist_directory="my_dir")
)
client_settings: chromadb.config.Settings,
eg: Settings(chroma_db_impl="duckdb+parquet", persist_directory="my_dir")
persist_directory: str, eg: "my_dir"
"""
    )
class Chroma(VectorStore):
"""`ChromaDB` vector store.
@@ -65,6 +96,7 @@ class Chroma(VectorStore):
"""
_LANGCHAIN_DEFAULT_COLLECTION_NAME = "langchain"
_client: chromadb.Client = None
def __init__(
self,
@@ -86,41 +118,48 @@ class Chroma(VectorStore):
"Please install it with `pip install chromadb`."
)
if client is not None:
self._client_settings = client_settings
self._client = client
self._persist_directory = persist_directory
else:
if client_settings:
# If client_settings is provided with persist_directory specified,
# then it is "in-memory and persisting to disk" mode.
client_settings.persist_directory = (
persist_directory or client_settings.persist_directory
)
if client_settings.persist_directory is not None:
# Maintain backwards compatibility with chromadb < 0.4.0
major, minor, _ = chromadb.__version__.split(".")
if int(major) == 0 and int(minor) < 4:
client_settings.chroma_db_impl = "duckdb+parquet"
_check_client_params(client, client_settings, persist_directory)
# default settings
_client_settings = chromadb.config.Settings()
self._client = chromadb.EphemeralClient(settings=_client_settings)
_client_settings = client_settings
elif persist_directory:
# Maintain backwards compatibility with chromadb < 0.4.0
major, minor, _ = chromadb.__version__.split(".")
if int(major) == 0 and int(minor) < 4:
_client_settings = chromadb.config.Settings(
chroma_db_impl="duckdb+parquet",
)
else:
_client_settings = chromadb.config.Settings(is_persistent=True)
_client_settings.persist_directory = persist_directory
# if user provides client, use it
if client:
self._client = client
# get chromadb version
major_version, minor_version, _ = chromadb.__version__.split(".")
pre_04_chroma = int(major_version) == 0 and int(minor_version) < 4
if pre_04_chroma:
logger.info("Using ChromaDB <0.4.0. Consider upgrading >= 0.4.0")
# if user provides persist_directory, use it
if persist_directory:
if pre_04_chroma:
_client_settings = chromadb.config.Settings(
chroma_db_impl="duckdb+parquet",
)
self._persist_directory = persist_directory
else:
_client_settings = chromadb.config.Settings()
self._client_settings = _client_settings
_client_settings = chromadb.config.Settings(is_persistent=True)
# We don't set self._persist_directory because
# .persist() is automatic in ChromaDB >= 0.4.0
_client_settings.persist_directory = persist_directory
self._client = chromadb.Client(_client_settings)
self._persist_directory = (
_client_settings.persist_directory or persist_directory
)
# if user provides client_settings, use it - this should be a rare case
if client_settings:
if (
pre_04_chroma
and client_settings.persist_directory is not None
and client_settings.chroma_db_impl is None
):
raise Exception(
"When using persist_directory with ChromaDB < 0.4.0, set "
"chroma_db_impl='duckdb+parquet' in client_settings"
)
self._client = chromadb.Client(client_settings)
self._embedding_function = embedding_function
self._collection = self._client.get_or_create_collection(
@@ -148,7 +187,7 @@ class Chroma(VectorStore):
try:
import chromadb # noqa: F401
except ImportError:
raise ValueError(
raise ImportError(
"Could not import chromadb python package. "
"Please install it with `pip install chromadb`."
)
@@ -604,21 +643,27 @@ class Chroma(VectorStore):
return self._collection.get(**kwargs)
def persist(self) -> None:
    """Persist the collection. Obsolete in ChromaDB >= 0.4.0.

    This can be used to explicitly persist the data to disk on
    ChromaDB < 0.4.0. On ChromaDB >= 0.4.0 it only logs a notice and
    returns without touching the client.

    Raises:
        ValueError: If running ChromaDB < 0.4.0 and no persist_directory
            was specified on creation.
    """
    import chromadb

    major, minor, *_ = chromadb.__version__.split(".")
    # Compare (major, minor) as a tuple so that 0.4.x itself and every later
    # major version take the "automatic persistence" branch; a check of
    # `major == 0 and minor > 4` would miss both.
    if (int(major), int(minor)) >= (0, 4):
        logger.info(
            "Chroma 0.4.0 and above automatically persist the collection. You can "
            "remove this call."
        )
        return

    # Only reached on ChromaDB < 0.4.0, where persistence is explicit.
    if self._persist_directory is None:
        raise ValueError(
            "You must specify a persist_directory on "
            "creation to persist the collection."
        )
    self._client.persist()
def update_document(self, document_id: str, document: Document) -> None:

View File

@@ -12,6 +12,15 @@ from tests.integration_tests.vectorstores.fake_embeddings import (
FakeEmbeddings,
)
# reset the collection before each test
@pytest.fixture(autouse=True)
def reset_collection() -> None:
    """Delete the collections used by this module's tests (autouse fixture)."""
    for collection_name in ("test_collection", "langchain"):
        Chroma(collection_name=collection_name).delete_collection()
def test_chroma() -> None:
"""Test end to end construction and search."""
@@ -136,7 +145,7 @@ def test_chroma_with_persistence() -> None:
output = docsearch.similarity_search("foo", k=1)
assert output == [Document(page_content="foo")]
docsearch.persist()
# docsearch.persist() # no longer needed in chromadb 0.4+
# Get a new VectorStore from the persisted directory
docsearch = Chroma(
@@ -234,23 +243,27 @@ def test_chroma_update_document() -> None:
assert new_embedding != old_embedding
def test_chroma_with_relevance_score() -> None:
"""Test to make sure the relevance score is scaled to 0-1."""
texts = ["foo", "bar", "baz"]
metadatas = [{"page": str(i)} for i in range(len(texts))]
docsearch = Chroma.from_texts(
collection_name="test_collection",
texts=texts,
embedding=FakeEmbeddings(),
metadatas=metadatas,
collection_metadata={"hnsw:space": "l2"},
)
output = docsearch.similarity_search_with_relevance_scores("foo", k=3)
assert output == [
(Document(page_content="foo", metadata={"page": "0"}), 1.0),
(Document(page_content="bar", metadata={"page": "1"}), 0.8),
(Document(page_content="baz", metadata={"page": "2"}), 0.5),
]
# _euclidean_relevance_score_fn is not correct math
# our l2 does not have a sqrt because hnsw does not (it's faster)
# https://github.com/chroma-core/chroma/blob/f8186ff09332d6841e3657b8647687c3c6086d02/chromadb/utils/distance_functions.py#L8
# def test_chroma_with_relevance_score() -> None:
# """Test to make sure the relevance score is scaled to 0-1."""
# texts = ["foo", "bar", "baz"]
# metadatas = [{"page": str(i)} for i in range(len(texts))]
# docsearch = Chroma.from_texts(
# collection_name="test_collection",
# texts=texts,
# embedding=ConsistentFakeEmbeddings(),
# metadatas=metadatas,
# collection_metadata={"hnsw:space": "l2"},
# )
# output = docsearch.similarity_search_with_relevance_scores("foo", k=3)
# assert output == [
# (Document(page_content="foo", metadata={"page": "0"}), 1.0),
# (Document(page_content="bar", metadata={"page": "1"}), 0.8),
# (Document(page_content="baz", metadata={"page": "2"}), 0.5),
# ]
def test_chroma_with_relevance_score_custom_normalization_fn() -> None:
@@ -307,6 +320,62 @@ def test_chroma_add_documents_mixed_metadata() -> None:
)
def test_multiple_inputs_fails() -> None:
    """Expect ValueError when persist_directory and client_settings are both given."""
    conflicting_kwargs = {"persist_directory": "foo", "client_settings": {}}
    with pytest.raises(ValueError):
        Chroma(embedding_function=FakeEmbeddings(), **conflicting_kwargs)
def test_no_config_uses_ephemeral() -> None:
    """With no client configuration, the client settings are not persistent."""
    store = Chroma(embedding_function=FakeEmbeddings())
    settings = store._collection._client.get_settings()
    assert settings.is_persistent is False
def test_persist_dir_uses_persist() -> None:
    """Passing persist_directory yields client settings marked persistent."""
    store = Chroma(persist_directory="foo", embedding_function=FakeEmbeddings())
    settings = store._collection._client.get_settings()
    assert settings.is_persistent is True
def test_client_settings_uses_persist() -> None:
    """client_settings with is_persistent=True is reflected by the client."""
    import chromadb

    persistent_settings = chromadb.config.Settings(is_persistent=True)
    store = Chroma(
        embedding_function=FakeEmbeddings(),
        client_settings=persistent_settings,
    )
    assert store._collection._client.get_settings().is_persistent is True
def test_client_settings_uses_no_persist() -> None:
    """Default client_settings leave the client non-persistent."""
    import chromadb

    default_settings = chromadb.config.Settings()
    store = Chroma(
        embedding_function=FakeEmbeddings(),
        client_settings=default_settings,
    )
    assert store._collection._client.get_settings().is_persistent is False
def test_respects_client_ephemeral() -> None:
    """A caller-supplied EphemeralClient is used as-is and is not persistent."""
    import chromadb

    ephemeral_client = chromadb.EphemeralClient()
    store = Chroma(embedding_function=FakeEmbeddings(), client=ephemeral_client)
    assert store._collection._client == ephemeral_client
    assert store._collection._client.get_settings().is_persistent is False
def test_respects_client_persistent() -> None:
    """A caller-supplied PersistentClient is used as-is and is persistent."""
    import chromadb

    persistent_client = chromadb.PersistentClient()
    store = Chroma(embedding_function=FakeEmbeddings(), client=persistent_client)
    assert store._collection._client == persistent_client
    assert store._collection._client.get_settings().is_persistent is True
def is_api_accessible(url: str) -> bool:
try:
response = requests.get(url)