mirror of
https://github.com/hwchase17/langchain.git
synced 2025-07-07 05:30:39 +00:00
Add retry logic for ChromaDB (#3372)
Rewrite of #3368 Mainly an issue for when people are just getting started, but still nice to not throw an error if the number of docs is < k. Add a little decorator utility to block mutually exclusive keyword arguments
This commit is contained in:
parent
6b49be9951
commit
b89c258bc5
@ -1,6 +1,6 @@
|
|||||||
"""Generic utility functions."""
|
"""Generic utility functions."""
|
||||||
import os
|
import os
|
||||||
from typing import Any, Dict, Optional
|
from typing import Any, Callable, Dict, Optional, Tuple
|
||||||
|
|
||||||
|
|
||||||
def get_from_dict_or_env(
|
def get_from_dict_or_env(
|
||||||
@ -19,3 +19,28 @@ def get_from_dict_or_env(
|
|||||||
f" `{env_key}` which contains it, or pass"
|
f" `{env_key}` which contains it, or pass"
|
||||||
f" `{key}` as a named parameter."
|
f" `{key}` as a named parameter."
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def xor_args(*arg_groups: Tuple[str, ...]) -> Callable:
|
||||||
|
"""Validate specified keyword args are mutually exclusive."""
|
||||||
|
|
||||||
|
def decorator(func: Callable) -> Callable:
|
||||||
|
def wrapper(*args: Any, **kwargs: Any) -> Callable:
|
||||||
|
"""Validate exactly one arg in each group is not None."""
|
||||||
|
counts = [
|
||||||
|
sum(1 for arg in arg_group if kwargs.get(arg) is not None)
|
||||||
|
for arg_group in arg_groups
|
||||||
|
]
|
||||||
|
invalid_groups = [i for i, count in enumerate(counts) if count != 1]
|
||||||
|
if invalid_groups:
|
||||||
|
invalid_group_names = [", ".join(arg_groups[i]) for i in invalid_groups]
|
||||||
|
raise ValueError(
|
||||||
|
"Exactly one argument in each of the following"
|
||||||
|
" groups must be defined:"
|
||||||
|
f" {', '.join(invalid_group_names)}"
|
||||||
|
)
|
||||||
|
return func(*args, **kwargs)
|
||||||
|
|
||||||
|
return wrapper
|
||||||
|
|
||||||
|
return decorator
|
||||||
|
@ -9,6 +9,7 @@ import numpy as np
|
|||||||
|
|
||||||
from langchain.docstore.document import Document
|
from langchain.docstore.document import Document
|
||||||
from langchain.embeddings.base import Embeddings
|
from langchain.embeddings.base import Embeddings
|
||||||
|
from langchain.utils import xor_args
|
||||||
from langchain.vectorstores.base import VectorStore
|
from langchain.vectorstores.base import VectorStore
|
||||||
from langchain.vectorstores.utils import maximal_marginal_relevance
|
from langchain.vectorstores.utils import maximal_marginal_relevance
|
||||||
|
|
||||||
@ -96,6 +97,32 @@ class Chroma(VectorStore):
|
|||||||
metadata=collection_metadata,
|
metadata=collection_metadata,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@xor_args(("query_texts", "query_embeddings"))
|
||||||
|
def __query_collection(
|
||||||
|
self,
|
||||||
|
query_texts: Optional[List[str]] = None,
|
||||||
|
query_embeddings: Optional[List[List[float]]] = None,
|
||||||
|
n_results: int = 4,
|
||||||
|
where: Optional[Dict[str, str]] = None,
|
||||||
|
) -> List[Document]:
|
||||||
|
"""Query the chroma collection."""
|
||||||
|
for i in range(n_results, 0, -1):
|
||||||
|
try:
|
||||||
|
return self._collection.query(
|
||||||
|
query_texts=query_texts,
|
||||||
|
query_embeddings=query_embeddings,
|
||||||
|
n_results=n_results,
|
||||||
|
where=where,
|
||||||
|
)
|
||||||
|
except chromadb.errors.NotEnoughElementsException:
|
||||||
|
logger.error(
|
||||||
|
f"Chroma collection {self._collection.name} "
|
||||||
|
f"contains fewer than {i} elements."
|
||||||
|
)
|
||||||
|
raise chromadb.errors.NotEnoughElementsException(
|
||||||
|
f"No documents found for Chroma collection {self._collection.name}"
|
||||||
|
)
|
||||||
|
|
||||||
def add_texts(
|
def add_texts(
|
||||||
self,
|
self,
|
||||||
texts: Iterable[str],
|
texts: Iterable[str],
|
||||||
@ -158,7 +185,7 @@ class Chroma(VectorStore):
|
|||||||
Returns:
|
Returns:
|
||||||
List of Documents most similar to the query vector.
|
List of Documents most similar to the query vector.
|
||||||
"""
|
"""
|
||||||
results = self._collection.query(
|
results = self.__query_collection(
|
||||||
query_embeddings=embedding, n_results=k, where=filter
|
query_embeddings=embedding, n_results=k, where=filter
|
||||||
)
|
)
|
||||||
return _results_to_docs(results)
|
return _results_to_docs(results)
|
||||||
@ -182,12 +209,12 @@ class Chroma(VectorStore):
|
|||||||
text with distance in float.
|
text with distance in float.
|
||||||
"""
|
"""
|
||||||
if self._embedding_function is None:
|
if self._embedding_function is None:
|
||||||
results = self._collection.query(
|
results = self.__query_collection(
|
||||||
query_texts=[query], n_results=k, where=filter
|
query_texts=[query], n_results=k, where=filter
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
query_embedding = self._embedding_function.embed_query(query)
|
query_embedding = self._embedding_function.embed_query(query)
|
||||||
results = self._collection.query(
|
results = self.__query_collection(
|
||||||
query_embeddings=[query_embedding], n_results=k, where=filter
|
query_embeddings=[query_embedding], n_results=k, where=filter
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -218,7 +245,7 @@ class Chroma(VectorStore):
|
|||||||
List of Documents selected by maximal marginal relevance.
|
List of Documents selected by maximal marginal relevance.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
results = self._collection.query(
|
results = self.__query_collection(
|
||||||
query_embeddings=embedding,
|
query_embeddings=embedding,
|
||||||
n_results=fetch_k,
|
n_results=fetch_k,
|
||||||
where=filter,
|
where=filter,
|
||||||
|
Loading…
Reference in New Issue
Block a user