mirror of
https://github.com/hwchase17/langchain.git
synced 2025-07-14 00:47:27 +00:00
infra: add test for ensemble retriever to ensure multiple retrievers (#8401)
Add tests to ensemble retriever to ensure it works with combination of multiple retrievers --------- Co-authored-by: Bagatur <baskaryan@gmail.com>
This commit is contained in:
parent
5738143d4b
commit
c502736841
@ -1,6 +1,8 @@
|
||||
import pytest
|
||||
from langchain_core.documents import Document
|
||||
|
||||
from langchain.embeddings import FakeEmbeddings
|
||||
from langchain.retrievers import KNNRetriever, TFIDFRetriever
|
||||
from langchain.retrievers.bm25 import BM25Retriever
|
||||
from langchain.retrievers.ensemble import EnsembleRetriever
|
||||
|
||||
@ -40,3 +42,38 @@ def test_weighted_reciprocal_rank() -> None:
|
||||
result = ensemble_retriever.weighted_reciprocal_rank([[doc1, doc2], [doc2, doc1]])
|
||||
assert result[0].page_content == "1"
|
||||
assert result[1].page_content == "2"
|
||||
|
||||
|
||||
@pytest.mark.requires("rank_bm25", "sklearn")
|
||||
def test_ensemble_retriever_get_relevant_docs_with_multiple_retrievers() -> None:
|
||||
doc_list_a = [
|
||||
"I like apples",
|
||||
"I like oranges",
|
||||
"Apples and oranges are fruits",
|
||||
]
|
||||
doc_list_b = [
|
||||
"I like melons",
|
||||
"I like pineapples",
|
||||
"Melons and pineapples are fruits",
|
||||
]
|
||||
doc_list_c = [
|
||||
"I like avocados",
|
||||
"I like strawberries",
|
||||
"Avocados and strawberries are fruits",
|
||||
]
|
||||
|
||||
dummy_retriever = BM25Retriever.from_texts(doc_list_a)
|
||||
dummy_retriever.k = 1
|
||||
tfidf_retriever = TFIDFRetriever.from_texts(texts=doc_list_b)
|
||||
tfidf_retriever.k = 1
|
||||
knn_retriever = KNNRetriever.from_texts(
|
||||
texts=doc_list_c, embeddings=FakeEmbeddings(size=100)
|
||||
)
|
||||
knn_retriever.k = 1
|
||||
|
||||
ensemble_retriever = EnsembleRetriever(
|
||||
retrievers=[dummy_retriever, tfidf_retriever, knn_retriever],
|
||||
weights=[0.6, 0.3, 0.1],
|
||||
)
|
||||
docs = ensemble_retriever.get_relevant_documents("I like apples")
|
||||
assert len(docs) == 3
|
||||
|
Loading…
Reference in New Issue
Block a user