From c50273684164e01d020aa82ab656073a25992577 Mon Sep 17 00:00:00 2001
From: shibuiwilliam
Date: Wed, 14 Feb 2024 14:22:03 +0900
Subject: [PATCH] infra: add test for ensemble retriever to ensure multiple retrievers (#8401)

Add tests to ensemble retriever to ensure it works with combination of multiple retrievers

---------

Co-authored-by: Bagatur
---
 .../unit_tests/retrievers/test_ensemble.py | 37 +++++++++++++++++++
 1 file changed, 37 insertions(+)

diff --git a/libs/langchain/tests/unit_tests/retrievers/test_ensemble.py b/libs/langchain/tests/unit_tests/retrievers/test_ensemble.py
index a84ee1f26a7..8b9cbbed115 100644
--- a/libs/langchain/tests/unit_tests/retrievers/test_ensemble.py
+++ b/libs/langchain/tests/unit_tests/retrievers/test_ensemble.py
@@ -1,6 +1,8 @@
 import pytest
 from langchain_core.documents import Document
 
+from langchain.embeddings import FakeEmbeddings
+from langchain.retrievers import KNNRetriever, TFIDFRetriever
 from langchain.retrievers.bm25 import BM25Retriever
 from langchain.retrievers.ensemble import EnsembleRetriever
 
@@ -40,3 +42,38 @@ def test_weighted_reciprocal_rank() -> None:
     result = ensemble_retriever.weighted_reciprocal_rank([[doc1, doc2], [doc2, doc1]])
     assert result[0].page_content == "1"
     assert result[1].page_content == "2"
+
+
+@pytest.mark.requires("rank_bm25", "sklearn")
+def test_ensemble_retriever_get_relevant_docs_with_multiple_retrievers() -> None:
+    doc_list_a = [
+        "I like apples",
+        "I like oranges",
+        "Apples and oranges are fruits",
+    ]
+    doc_list_b = [
+        "I like melons",
+        "I like pineapples",
+        "Melons and pineapples are fruits",
+    ]
+    doc_list_c = [
+        "I like avocados",
+        "I like strawberries",
+        "Avocados and strawberries are fruits",
+    ]
+
+    dummy_retriever = BM25Retriever.from_texts(doc_list_a)
+    dummy_retriever.k = 1
+    tfidf_retriever = TFIDFRetriever.from_texts(texts=doc_list_b)
+    tfidf_retriever.k = 1
+    knn_retriever = KNNRetriever.from_texts(
+        texts=doc_list_c, embeddings=FakeEmbeddings(size=100)
+    )
+    knn_retriever.k = 1
+
+    ensemble_retriever = EnsembleRetriever(
+        retrievers=[dummy_retriever, tfidf_retriever, knn_retriever],
+        weights=[0.6, 0.3, 0.1],
+    )
+    docs = ensemble_retriever.get_relevant_documents("I like apples")
+    assert len(docs) == 3