mirror of
https://github.com/hwchase17/langchain.git
synced 2025-07-09 14:35:50 +00:00
Chroma fix mmr (#3897)
Fixes #3628, thanks @derekmoeller for the issue!
This commit is contained in:
parent
3e1cb31f63
commit
2451310975
@ -104,8 +104,17 @@ class Chroma(VectorStore):
|
|||||||
query_embeddings: Optional[List[List[float]]] = None,
|
query_embeddings: Optional[List[List[float]]] = None,
|
||||||
n_results: int = 4,
|
n_results: int = 4,
|
||||||
where: Optional[Dict[str, str]] = None,
|
where: Optional[Dict[str, str]] = None,
|
||||||
|
**kwargs: Any,
|
||||||
) -> List[Document]:
|
) -> List[Document]:
|
||||||
"""Query the chroma collection."""
|
"""Query the chroma collection."""
|
||||||
|
try:
|
||||||
|
import chromadb
|
||||||
|
except ImportError:
|
||||||
|
raise ValueError(
|
||||||
|
"Could not import chromadb python package. "
|
||||||
|
"Please install it with `pip install chromadb`."
|
||||||
|
)
|
||||||
|
|
||||||
for i in range(n_results, 0, -1):
|
for i in range(n_results, 0, -1):
|
||||||
try:
|
try:
|
||||||
return self._collection.query(
|
return self._collection.query(
|
||||||
@ -113,6 +122,7 @@ class Chroma(VectorStore):
|
|||||||
query_embeddings=query_embeddings,
|
query_embeddings=query_embeddings,
|
||||||
n_results=i,
|
n_results=i,
|
||||||
where=where,
|
where=where,
|
||||||
|
**kwargs,
|
||||||
)
|
)
|
||||||
except chromadb.errors.NotEnoughElementsException:
|
except chromadb.errors.NotEnoughElementsException:
|
||||||
logger.error(
|
logger.error(
|
||||||
|
@ -126,3 +126,25 @@ def test_chroma_with_persistence() -> None:
|
|||||||
# Persist doesn't need to be called again
|
# Persist doesn't need to be called again
|
||||||
# Data will be automatically persisted on object deletion
|
# Data will be automatically persisted on object deletion
|
||||||
# Or on program exit
|
# Or on program exit
|
||||||
|
|
||||||
|
|
||||||
|
def test_chroma_mmr() -> None:
|
||||||
|
"""Test end to end construction and search."""
|
||||||
|
texts = ["foo", "bar", "baz"]
|
||||||
|
docsearch = Chroma.from_texts(
|
||||||
|
collection_name="test_collection", texts=texts, embedding=FakeEmbeddings()
|
||||||
|
)
|
||||||
|
output = docsearch.max_marginal_relevance_search("foo", k=1)
|
||||||
|
assert output == [Document(page_content="foo")]
|
||||||
|
|
||||||
|
|
||||||
|
def test_chroma_mmr_by_vector() -> None:
|
||||||
|
"""Test end to end construction and search."""
|
||||||
|
texts = ["foo", "bar", "baz"]
|
||||||
|
embeddings = FakeEmbeddings()
|
||||||
|
docsearch = Chroma.from_texts(
|
||||||
|
collection_name="test_collection", texts=texts, embedding=embeddings
|
||||||
|
)
|
||||||
|
embedded_query = embeddings.embed_query("foo")
|
||||||
|
output = docsearch.max_marginal_relevance_search_by_vector(embedded_query, k=1)
|
||||||
|
assert output == [Document(page_content="foo")]
|
||||||
|
Loading…
Reference in New Issue
Block a user