diff --git a/libs/langchain/tests/unit_tests/retrievers/test_ensemble.py b/libs/langchain/tests/unit_tests/retrievers/test_ensemble.py index a84ee1f26a7..8b9cbbed115 100644 --- a/libs/langchain/tests/unit_tests/retrievers/test_ensemble.py +++ b/libs/langchain/tests/unit_tests/retrievers/test_ensemble.py @@ -1,6 +1,8 @@ import pytest from langchain_core.documents import Document +from langchain.embeddings import FakeEmbeddings +from langchain.retrievers import KNNRetriever, TFIDFRetriever from langchain.retrievers.bm25 import BM25Retriever from langchain.retrievers.ensemble import EnsembleRetriever @@ -40,3 +42,38 @@ def test_weighted_reciprocal_rank() -> None: result = ensemble_retriever.weighted_reciprocal_rank([[doc1, doc2], [doc2, doc1]]) assert result[0].page_content == "1" assert result[1].page_content == "2" + + +@pytest.mark.requires("rank_bm25", "sklearn") +def test_ensemble_retriever_get_relevant_docs_with_multiple_retrievers() -> None: + doc_list_a = [ + "I like apples", + "I like oranges", + "Apples and oranges are fruits", + ] + doc_list_b = [ + "I like melons", + "I like pineapples", + "Melons and pineapples are fruits", + ] + doc_list_c = [ + "I like avocados", + "I like strawberries", + "Avocados and strawberries are fruits", + ] + + dummy_retriever = BM25Retriever.from_texts(doc_list_a) + dummy_retriever.k = 1 + tfidf_retriever = TFIDFRetriever.from_texts(texts=doc_list_b) + tfidf_retriever.k = 1 + knn_retriever = KNNRetriever.from_texts( + texts=doc_list_c, embeddings=FakeEmbeddings(size=100) + ) + knn_retriever.k = 1 + + ensemble_retriever = EnsembleRetriever( + retrievers=[dummy_retriever, tfidf_retriever, knn_retriever], + weights=[0.6, 0.3, 0.1], + ) + docs = ensemble_retriever.get_relevant_documents("I like apples") + assert len(docs) == 3