diff --git a/langchain/vectorstores/chroma.py b/langchain/vectorstores/chroma.py index 7d29dbe5d72..19617947048 100644 --- a/langchain/vectorstores/chroma.py +++ b/langchain/vectorstores/chroma.py @@ -128,6 +128,7 @@ class Chroma(VectorStore): self, query: str, k: int = 4, + find_highest_possible_k: Optional[bool] = True, filter: Optional[Dict[str, str]] = None, **kwargs: Any, ) -> List[Document]: @@ -136,18 +137,32 @@ class Chroma(VectorStore): Args: query (str): Query text to search for. k (int): Number of results to return. Defaults to 4. + find_highest_possible_k (Optional[bool], True): If True, will iteratively lower k + until there are enough items in the vectorstore to not raise an Error. filter (Optional[Dict[str, str]]): Filter by metadata. Defaults to None. Returns: List[Document]: List of documents most similar to the query text. """ - docs_and_scores = self.similarity_search_with_score(query, k, filter=filter) - return [doc for doc, _ in docs_and_scores] + def _similarity_search(k: int): + docs_and_scores = self.similarity_search_with_score(query, k, filter=filter) + return [doc for doc, _ in docs_and_scores] + if not find_highest_possible_k: + return _similarity_search(k=k) + + # Iteratively lower k until an error isn't raised by Chroma + for try_k in range(k, 0, -1): + try: + return _similarity_search(k=try_k) + except chromadb.errors.NotEnoughElementsException: + continue + def similarity_search_by_vector( self, embedding: List[float], k: int = 4, + find_highest_possible_k: Optional[bool] = True, filter: Optional[Dict[str, str]] = None, **kwargs: Any, ) -> List[Document]: @@ -155,18 +170,32 @@ class Chroma(VectorStore): Args: embedding: Embedding to look up documents similar to. k: Number of Documents to return. Defaults to 4. + find_highest_possible_k (Optional[bool], True): If True, will iteratively lower k + until there are enough items in the vectorstore to not raise an Error. Returns: List of Documents most similar to the query vector. """ - results = self._collection.query( - query_embeddings=embedding, n_results=k, where=filter - ) - return _results_to_docs(results) + def _similarity_search(k: int): + results = self._collection.query( + query_embeddings=embedding, n_results=k, where=filter + ) + return _results_to_docs(results) + + if not find_highest_possible_k: + return _similarity_search(k=k) + + # Iteratively lower k until an error isn't raised by Chroma + for try_k in range(k, 0, -1): + try: + return _similarity_search(k=try_k) + except chromadb.errors.NotEnoughElementsException: + continue def similarity_search_with_score( self, query: str, k: int = 4, + find_highest_possible_k: Optional[bool] = True, filter: Optional[Dict[str, str]] = None, **kwargs: Any, ) -> List[Tuple[Document, float]]: @@ -175,23 +204,36 @@ class Chroma(VectorStore): Args: query (str): Query text to search for. k (int): Number of results to return. Defaults to 4. + find_highest_possible_k (Optional[bool], True): If True, will iteratively lower k + until there are enough items in the vectorstore to not raise an Error. filter (Optional[Dict[str, str]]): Filter by metadata. Defaults to None. Returns: List[Tuple[Document, float]]: List of documents most similar to the query text with distance in float. """ - if self._embedding_function is None: - results = self._collection.query( - query_texts=[query], n_results=k, where=filter - ) - else: - query_embedding = self._embedding_function.embed_query(query) - results = self._collection.query( - query_embeddings=[query_embedding], n_results=k, where=filter - ) + def _similarity_search(k: int): + if self._embedding_function is None: + results = self._collection.query( + query_texts=[query], n_results=k, where=filter + ) + else: + query_embedding = self._embedding_function.embed_query(query) + results = self._collection.query( + query_embeddings=[query_embedding], n_results=k, where=filter + ) - return _results_to_docs_and_scores(results) + return _results_to_docs_and_scores(results) + + if not find_highest_possible_k: + return _similarity_search(k=k) + + # Iteratively lower k until an error isn't raised by Chroma + for try_k in range(k, 0, -1): + try: + return _similarity_search(k=try_k) + except chromadb.errors.NotEnoughElementsException: + continue def max_marginal_relevance_search_by_vector( self,