mirror of
https://github.com/hwchase17/langchain.git
synced 2025-09-21 18:39:57 +00:00
langchain[minor], community[minor]: add CrossEncoderReranker with HuggingFaceCrossEncoder and SagemakerEndpointCrossEncoder (#13687)
- **Description:** Support reranking based on cross encoder models available from HuggingFace. - Added `CrossEncoder` schema - Implemented `HuggingFaceCrossEncoder` and `SagemakerEndpointCrossEncoder` - Implemented `CrossEncoderReranker` that performs similar functionality to `CohereRerank` - Added `cross-encoder-reranker.ipynb` to demonstrate how to use it. Please let me know if anything else needs to be done to make it visible on the table-of-contents navigation bar on the left, or on the card list on [retrievers documentation page](https://python.langchain.com/docs/integrations/retrievers). - **Issue:** N/A - **Dependencies:** None other than the existing ones. --------- Co-authored-by: Kenny Choe <kchoe@amazon.com> Co-authored-by: Bagatur <baskaryan@gmail.com>
This commit is contained in:
@@ -6,6 +6,9 @@ from langchain.retrievers.document_compressors.chain_filter import (
|
||||
LLMChainFilter,
|
||||
)
|
||||
from langchain.retrievers.document_compressors.cohere_rerank import CohereRerank
|
||||
from langchain.retrievers.document_compressors.cross_encoder_rerank import (
|
||||
CrossEncoderReranker,
|
||||
)
|
||||
from langchain.retrievers.document_compressors.embeddings_filter import (
|
||||
EmbeddingsFilter,
|
||||
)
|
||||
@@ -17,5 +20,6 @@ __all__ = [
|
||||
"LLMChainExtractor",
|
||||
"LLMChainFilter",
|
||||
"CohereRerank",
|
||||
"CrossEncoderReranker",
|
||||
"FlashrankRerank",
|
||||
]
|
||||
|
@@ -0,0 +1,47 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import operator
|
||||
from typing import Optional, Sequence
|
||||
|
||||
from langchain_community.cross_encoders import BaseCrossEncoder
|
||||
from langchain_core.callbacks import Callbacks
|
||||
from langchain_core.documents import BaseDocumentCompressor, Document
|
||||
from langchain_core.pydantic_v1 import Extra
|
||||
|
||||
|
||||
class CrossEncoderReranker(BaseDocumentCompressor):
    """Document compressor that uses CrossEncoder for reranking."""

    model: BaseCrossEncoder
    """CrossEncoder model to use for scoring similarity
    between the query and documents."""
    top_n: int = 3
    """Number of documents to return."""

    class Config:
        """Configuration for this pydantic object."""

        extra = Extra.forbid
        arbitrary_types_allowed = True

    def compress_documents(
        self,
        documents: Sequence[Document],
        query: str,
        callbacks: Optional[Callbacks] = None,
    ) -> Sequence[Document]:
        """
        Rerank documents using CrossEncoder.

        Args:
            documents: A sequence of documents to compress.
            query: The query to use for compressing the documents.
            callbacks: Callbacks to run during the compression process.

        Returns:
            A sequence of compressed documents.
        """
        # Score every (query, document) pair with the cross-encoder model.
        pairs = [(query, document.page_content) for document in documents]
        similarity_scores = self.model.score(pairs)
        # Order documents by descending relevance score and keep the top_n.
        ranked = sorted(
            zip(documents, similarity_scores),
            key=lambda doc_score: doc_score[1],
            reverse=True,
        )
        return [document for document, _ in ranked[: self.top_n]]
|
@@ -0,0 +1,34 @@
|
||||
"""Integration test for CrossEncoderReranker."""
|
||||
from typing import List
|
||||
|
||||
from langchain_community.cross_encoders import FakeCrossEncoder
|
||||
from langchain_core.documents import Document
|
||||
|
||||
from langchain.retrievers.document_compressors import CrossEncoderReranker
|
||||
|
||||
|
||||
def test_rerank() -> None:
    """Reranker returns the top-3 'bbb' docs for a 'bbb' query, best first."""
    texts = [
        "aaa1",
        "bbb1",
        "aaa2",
        "bbb2",
        "aaa3",
        "bbb3",
    ]
    docs = [Document(page_content=text) for text in texts]
    compressor = CrossEncoderReranker(model=FakeCrossEncoder())

    reranked_docs = compressor.compress_documents(docs, "bbb2")
    actual = [doc.page_content for doc in reranked_docs]

    expected_returned = ["bbb2", "bbb1", "bbb3"]
    expected_not_returned = ["aaa1", "aaa2", "aaa3"]
    # All similar docs survive the cut, all dissimilar ones are dropped,
    # and the exact match ranks first.
    assert all(text in actual for text in expected_returned)
    assert all(text not in actual for text in expected_not_returned)
    assert actual[0] == "bbb2"
|
||||
|
||||
|
||||
def test_rerank_empty() -> None:
    """Reranking an empty document list yields an empty result."""
    no_docs: List[Document] = []
    reranker = CrossEncoderReranker(model=FakeCrossEncoder())

    reranked_docs = reranker.compress_documents(no_docs, "query")

    assert len(reranked_docs) == 0
|
Reference in New Issue
Block a user