diff --git a/libs/partners/chroma/langchain_chroma/vectorstores.py b/libs/partners/chroma/langchain_chroma/vectorstores.py index 35146fdcc76..30bc3fd8764 100644 --- a/libs/partners/chroma/langchain_chroma/vectorstores.py +++ b/libs/partners/chroma/langchain_chroma/vectorstores.py @@ -53,6 +53,17 @@ def _results_to_docs_and_scores(results: Any) -> List[Tuple[Document, float]]: ] +def _results_to_docs_and_vectors(results: Any) -> List[Tuple[Document, np.ndarray]]: + return [ + (Document(page_content=result[0], metadata=result[1] or {}), result[2]) + for result in zip( + results["documents"][0], + results["metadatas"][0], + results["embeddings"][0], + ) + ] + + Matrix = Union[List[List[float]], List[np.ndarray], np.ndarray] @@ -687,6 +698,51 @@ class Chroma(VectorStore): return _results_to_docs_and_scores(results) + def similarity_search_with_vectors( + self, + query: str, + k: int = DEFAULT_K, + filter: Optional[Dict[str, str]] = None, + where_document: Optional[Dict[str, str]] = None, + **kwargs: Any, + ) -> List[Tuple[Document, np.ndarray]]: + """Run similarity search with Chroma and return embedding vectors. + + Args: + query: Query text to search for. + k: Number of results to return. Defaults to DEFAULT_K. + filter: Filter by metadata. Defaults to None. + where_document: dict used to filter by the document contents. + E.g. {"$contains": "hello"}. + kwargs: Additional keyword arguments to pass to Chroma collection query. + + Returns: + List of (document, embedding vector) tuples for the documents + most similar to the query text. 
+ """ + include = ["documents", "metadatas", "embeddings"] + if self._embedding_function is None: + results = self.__query_collection( + query_texts=[query], + n_results=k, + where=filter, + where_document=where_document, + include=include, + **kwargs, + ) + else: + query_embedding = self._embedding_function.embed_query(query) + results = self.__query_collection( + query_embeddings=[query_embedding], + n_results=k, + where=filter, + where_document=where_document, + include=include, + **kwargs, + ) + + return _results_to_docs_and_vectors(results) + def _select_relevance_score_fn(self) -> Callable[[float], float]: """Select the relevance score function based on collections distance metric. diff --git a/libs/partners/chroma/tests/integration_tests/test_vectorstores.py b/libs/partners/chroma/tests/integration_tests/test_vectorstores.py index 4393d5f339b..d20d8f3d7c9 100644 --- a/libs/partners/chroma/tests/integration_tests/test_vectorstores.py +++ b/libs/partners/chroma/tests/integration_tests/test_vectorstores.py @@ -92,6 +92,24 @@ def test_chroma_with_metadatas_with_scores() -> None: assert output == [(Document(page_content="foo", metadata={"page": "0"}), 0.0)] +def test_chroma_with_metadatas_with_vectors() -> None: + """Test end to end construction and search returning embedding vectors.""" + texts = ["foo", "bar", "baz"] + metadatas = [{"page": str(i)} for i in range(len(texts))] + embeddings = ConsistentFakeEmbeddings() + docsearch = Chroma.from_texts( + collection_name="test_collection", + texts=texts, + embedding=embeddings, + metadatas=metadatas, + ) + vec_1 = embeddings.embed_query(texts[0]) + output = docsearch.similarity_search_with_vectors("foo", k=1) + docsearch.delete_collection() + assert output[0][0] == Document(page_content="foo", metadata={"page": "0"}) + assert (output[0][1] == vec_1).all() + + def test_chroma_with_metadatas_with_scores_using_vector() -> None: """Test end to end construction and scored search, using embedding vector.""" texts = ["foo", "bar", "baz"]