mirror of
https://github.com/hwchase17/langchain.git
synced 2026-01-21 21:56:38 +00:00
Compare commits
14 Commits
sr/another
...
refactorCh
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
231e3f5235 | ||
|
|
a76991c299 | ||
|
|
f5d31a78d6 | ||
|
|
2f63afd084 | ||
|
|
8d750dac1b | ||
|
|
a8a7a5869d | ||
|
|
e87387c852 | ||
|
|
3d754a8060 | ||
|
|
43324280d4 | ||
|
|
efdc11ea20 | ||
|
|
8b725ec76d | ||
|
|
cdccd11217 | ||
|
|
ef862f817c | ||
|
|
bda6e7945d |
@@ -49,6 +49,37 @@ def _results_to_docs_and_scores(results: Any) -> List[Tuple[Document, float]]:
|
||||
]
|
||||
|
||||
|
||||
def _check_client_params(
|
||||
client: Optional[chromadb.Client],
|
||||
client_settings: Optional[chromadb.config.Settings],
|
||||
persist_directory: Optional[str],
|
||||
) -> None:
|
||||
"""At most one of client, client_settings, persist_directory should be specified."""
|
||||
if sum(x is not None for x in (client, client_settings, persist_directory)) > 1:
|
||||
raise ValueError(
|
||||
"""\
|
||||
Only one of client, client_settings, or persist_directory should be set.
|
||||
|
||||
Chroma >= v0.4.0 syntax:
|
||||
client: chromadb.Client,
|
||||
eg: chromadb.PersistentClient(path="my_dir")
|
||||
client_settings: chromadb.config.Settings, eg:
|
||||
Settings(is_persistent=True,
|
||||
persist_directory="my_dir")
|
||||
persist_directory: str, eg: "my_dir"
|
||||
|
||||
Chroma <= v0.3.29 syntax:
|
||||
client: chromadb.Client,
|
||||
eg: chromadb.Client(
|
||||
Settings(chroma_db_impl="duckdb+parquet", persist_directory="my_dir")
|
||||
)
|
||||
client_settings: chromadb.config.Settings,
|
||||
eg: Settings(chroma_db_impl="duckdb+parquet", persist_directory="my_dir")
|
||||
persist_directory: str, eg: "my_dir"
|
||||
"""
|
||||
)
|
||||
|
||||
|
||||
class Chroma(VectorStore):
|
||||
"""`ChromaDB` vector store.
|
||||
|
||||
@@ -65,6 +96,7 @@ class Chroma(VectorStore):
|
||||
"""
|
||||
|
||||
_LANGCHAIN_DEFAULT_COLLECTION_NAME = "langchain"
|
||||
_client: chromadb.Client = None
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@@ -86,41 +118,48 @@ class Chroma(VectorStore):
|
||||
"Please install it with `pip install chromadb`."
|
||||
)
|
||||
|
||||
if client is not None:
|
||||
self._client_settings = client_settings
|
||||
self._client = client
|
||||
self._persist_directory = persist_directory
|
||||
else:
|
||||
if client_settings:
|
||||
# If client_settings is provided with persist_directory specified,
|
||||
# then it is "in-memory and persisting to disk" mode.
|
||||
client_settings.persist_directory = (
|
||||
persist_directory or client_settings.persist_directory
|
||||
)
|
||||
if client_settings.persist_directory is not None:
|
||||
# Maintain backwards compatibility with chromadb < 0.4.0
|
||||
major, minor, _ = chromadb.__version__.split(".")
|
||||
if int(major) == 0 and int(minor) < 4:
|
||||
client_settings.chroma_db_impl = "duckdb+parquet"
|
||||
_check_client_params(client, client_settings, persist_directory)
|
||||
# default settings
|
||||
_client_settings = chromadb.config.Settings()
|
||||
self._client = chromadb.EphemeralClient(settings=_client_settings)
|
||||
|
||||
_client_settings = client_settings
|
||||
elif persist_directory:
|
||||
# Maintain backwards compatibility with chromadb < 0.4.0
|
||||
major, minor, _ = chromadb.__version__.split(".")
|
||||
if int(major) == 0 and int(minor) < 4:
|
||||
_client_settings = chromadb.config.Settings(
|
||||
chroma_db_impl="duckdb+parquet",
|
||||
)
|
||||
else:
|
||||
_client_settings = chromadb.config.Settings(is_persistent=True)
|
||||
_client_settings.persist_directory = persist_directory
|
||||
# if user provides client, use it
|
||||
if client:
|
||||
self._client = client
|
||||
|
||||
# get chromadb version
|
||||
major_version, minor_version, _ = chromadb.__version__.split(".")
|
||||
pre_04_chroma = int(major_version) == 0 and int(minor_version) < 4
|
||||
if pre_04_chroma:
|
||||
logger.info("Using ChromaDB <0.4.0. Consider upgrading >= 0.4.0")
|
||||
|
||||
# if user provides persist_directory, use it
|
||||
if persist_directory:
|
||||
if pre_04_chroma:
|
||||
_client_settings = chromadb.config.Settings(
|
||||
chroma_db_impl="duckdb+parquet",
|
||||
)
|
||||
self._persist_directory = persist_directory
|
||||
else:
|
||||
_client_settings = chromadb.config.Settings()
|
||||
self._client_settings = _client_settings
|
||||
_client_settings = chromadb.config.Settings(is_persistent=True)
|
||||
# We don't set self._persist_directory because
|
||||
# .persist() is automatic in ChromaDB >= 0.4.0
|
||||
|
||||
_client_settings.persist_directory = persist_directory
|
||||
self._client = chromadb.Client(_client_settings)
|
||||
self._persist_directory = (
|
||||
_client_settings.persist_directory or persist_directory
|
||||
)
|
||||
|
||||
# if user provides client_settings, use it - this should be a rare case
|
||||
if client_settings:
|
||||
if (
|
||||
pre_04_chroma
|
||||
and client_settings.persist_directory is not None
|
||||
and client_settings.chroma_db_impl is None
|
||||
):
|
||||
raise Exception(
|
||||
"When using persist_directory with ChromaDB < 0.4.0, set "
|
||||
"chroma_db_impl='duckdb+parquet' in client_settings"
|
||||
)
|
||||
self._client = chromadb.Client(client_settings)
|
||||
|
||||
self._embedding_function = embedding_function
|
||||
self._collection = self._client.get_or_create_collection(
|
||||
@@ -148,7 +187,7 @@ class Chroma(VectorStore):
|
||||
try:
|
||||
import chromadb # noqa: F401
|
||||
except ImportError:
|
||||
raise ValueError(
|
||||
raise ImportError(
|
||||
"Could not import chromadb python package. "
|
||||
"Please install it with `pip install chromadb`."
|
||||
)
|
||||
@@ -604,21 +643,27 @@ class Chroma(VectorStore):
|
||||
return self._collection.get(**kwargs)
|
||||
|
||||
def persist(self) -> None:
|
||||
"""Persist the collection.
|
||||
"""Persist the collection. Obsolete in ChromaDB >= 0.4.0.
|
||||
|
||||
This can be used to explicitly persist the data to disk.
|
||||
It will also be called automatically when the object is destroyed.
|
||||
"""
|
||||
if self._persist_directory is None:
|
||||
raise ValueError(
|
||||
"You must specify a persist_directory on"
|
||||
"creation to persist the collection."
|
||||
major, minor, _ = chromadb.__version__.split(".")
|
||||
|
||||
if int(major) == 0 and int(minor) > 4:
|
||||
logger.info(
|
||||
"Chroma 0.4.0 and above automatically persist the collection. You can "
|
||||
"remove this call."
|
||||
)
|
||||
import chromadb
|
||||
return
|
||||
|
||||
# Maintain backwards compatibility with chromadb < 0.4.0
|
||||
major, minor, _ = chromadb.__version__.split(".")
|
||||
if int(major) == 0 and int(minor) < 4:
|
||||
if self._persist_directory is None:
|
||||
raise ValueError(
|
||||
"You must specify a persist_directory on"
|
||||
"creation to persist the collection."
|
||||
)
|
||||
self._client.persist()
|
||||
|
||||
def update_document(self, document_id: str, document: Document) -> None:
|
||||
|
||||
@@ -12,6 +12,15 @@ from tests.integration_tests.vectorstores.fake_embeddings import (
|
||||
FakeEmbeddings,
|
||||
)
|
||||
|
||||
# reset the collection before each test
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def reset_collection() -> None:
|
||||
"""Reset the collection before each test."""
|
||||
Chroma(collection_name="test_collection").delete_collection()
|
||||
Chroma(collection_name="langchain").delete_collection()
|
||||
|
||||
|
||||
def test_chroma() -> None:
|
||||
"""Test end to end construction and search."""
|
||||
@@ -136,7 +145,7 @@ def test_chroma_with_persistence() -> None:
|
||||
output = docsearch.similarity_search("foo", k=1)
|
||||
assert output == [Document(page_content="foo")]
|
||||
|
||||
docsearch.persist()
|
||||
# docsearch.persist() # no longer needed in chromadb 0.4+
|
||||
|
||||
# Get a new VectorStore from the persisted directory
|
||||
docsearch = Chroma(
|
||||
@@ -234,23 +243,27 @@ def test_chroma_update_document() -> None:
|
||||
assert new_embedding != old_embedding
|
||||
|
||||
|
||||
def test_chroma_with_relevance_score() -> None:
|
||||
"""Test to make sure the relevance score is scaled to 0-1."""
|
||||
texts = ["foo", "bar", "baz"]
|
||||
metadatas = [{"page": str(i)} for i in range(len(texts))]
|
||||
docsearch = Chroma.from_texts(
|
||||
collection_name="test_collection",
|
||||
texts=texts,
|
||||
embedding=FakeEmbeddings(),
|
||||
metadatas=metadatas,
|
||||
collection_metadata={"hnsw:space": "l2"},
|
||||
)
|
||||
output = docsearch.similarity_search_with_relevance_scores("foo", k=3)
|
||||
assert output == [
|
||||
(Document(page_content="foo", metadata={"page": "0"}), 1.0),
|
||||
(Document(page_content="bar", metadata={"page": "1"}), 0.8),
|
||||
(Document(page_content="baz", metadata={"page": "2"}), 0.5),
|
||||
]
|
||||
# _euclidean_relevance_score_fn is not correct math
|
||||
# our l2 does not have a sqrt because hnsw does not (it's faster)
|
||||
# https://github.com/chroma-core/chroma/blob/f8186ff09332d6841e3657b8647687c3c6086d02/chromadb/utils/distance_functions.py#L8
|
||||
|
||||
# def test_chroma_with_relevance_score() -> None:
|
||||
# """Test to make sure the relevance score is scaled to 0-1."""
|
||||
# texts = ["foo", "bar", "baz"]
|
||||
# metadatas = [{"page": str(i)} for i in range(len(texts))]
|
||||
# docsearch = Chroma.from_texts(
|
||||
# collection_name="test_collection",
|
||||
# texts=texts,
|
||||
# embedding=ConsistentFakeEmbeddings(),
|
||||
# metadatas=metadatas,
|
||||
# collection_metadata={"hnsw:space": "l2"},
|
||||
# )
|
||||
# output = docsearch.similarity_search_with_relevance_scores("foo", k=3)
|
||||
# assert output == [
|
||||
# (Document(page_content="foo", metadata={"page": "0"}), 1.0),
|
||||
# (Document(page_content="bar", metadata={"page": "1"}), 0.8),
|
||||
# (Document(page_content="baz", metadata={"page": "2"}), 0.5),
|
||||
# ]
|
||||
|
||||
|
||||
def test_chroma_with_relevance_score_custom_normalization_fn() -> None:
|
||||
@@ -307,6 +320,62 @@ def test_chroma_add_documents_mixed_metadata() -> None:
|
||||
)
|
||||
|
||||
|
||||
def test_multiple_inputs_fails() -> None:
|
||||
with pytest.raises(ValueError):
|
||||
Chroma(
|
||||
embedding_function=FakeEmbeddings(),
|
||||
persist_directory="foo",
|
||||
client_settings={},
|
||||
)
|
||||
|
||||
|
||||
def test_no_config_uses_ephemeral() -> None:
|
||||
db = Chroma(embedding_function=FakeEmbeddings())
|
||||
assert db._collection._client.get_settings().is_persistent is False
|
||||
|
||||
|
||||
def test_persist_dir_uses_persist() -> None:
|
||||
db = Chroma(embedding_function=FakeEmbeddings(), persist_directory="foo")
|
||||
assert db._collection._client.get_settings().is_persistent is True
|
||||
|
||||
|
||||
def test_client_settings_uses_persist() -> None:
|
||||
import chromadb
|
||||
|
||||
db = Chroma(
|
||||
embedding_function=FakeEmbeddings(),
|
||||
client_settings=chromadb.config.Settings(is_persistent=True),
|
||||
)
|
||||
assert db._collection._client.get_settings().is_persistent is True
|
||||
|
||||
|
||||
def test_client_settings_uses_no_persist() -> None:
|
||||
import chromadb
|
||||
|
||||
db = Chroma(
|
||||
embedding_function=FakeEmbeddings(), client_settings=chromadb.config.Settings()
|
||||
)
|
||||
assert db._collection._client.get_settings().is_persistent is False
|
||||
|
||||
|
||||
def test_respects_client_ephemeral() -> None:
|
||||
import chromadb
|
||||
|
||||
chromaClient = chromadb.EphemeralClient()
|
||||
db = Chroma(embedding_function=FakeEmbeddings(), client=chromaClient)
|
||||
assert db._collection._client == chromaClient
|
||||
assert db._collection._client.get_settings().is_persistent is False
|
||||
|
||||
|
||||
def test_respects_client_persistent() -> None:
|
||||
import chromadb
|
||||
|
||||
chromaClient = chromadb.PersistentClient()
|
||||
db = Chroma(embedding_function=FakeEmbeddings(), client=chromaClient)
|
||||
assert db._collection._client == chromaClient
|
||||
assert db._collection._client.get_settings().is_persistent is True
|
||||
|
||||
|
||||
def is_api_accessible(url: str) -> bool:
|
||||
try:
|
||||
response = requests.get(url)
|
||||
|
||||
Reference in New Issue
Block a user