langchain[minor], community[minor]: add CrossEncoderReranker with HuggingFaceCrossEncoder and SagemakerEndpointCrossEncoder (#13687)

- **Description:** Support reranking based on cross encoder models
available from HuggingFace.
      - Added `CrossEncoder` schema
- Implemented `HuggingFaceCrossEncoder` and
`SagemakerEndpointCrossEncoder`
- Implemented `CrossEncoderReranker` that performs similar functionality
to `CohereRerank`
- Added `cross-encoder-reranker.ipynb` to demonstrate how to use it.
Please let me know if anything else needs to be done to make it visible
on the table-of-contents navigation bar on the left, or on the card list
on [retrievers documentation
page](https://python.langchain.com/docs/integrations/retrievers).
  - **Issue:** N/A
  - **Dependencies:** None other than the existing ones.

---------

Co-authored-by: Kenny Choe <kchoe@amazon.com>
Co-authored-by: Bagatur <baskaryan@gmail.com>
This commit is contained in:
Kenneth Choe
2024-03-31 15:51:31 -05:00
committed by GitHub
parent 3f7da03dd8
commit f98d7f7494
11 changed files with 660 additions and 0 deletions

View File

@@ -6,6 +6,9 @@ from langchain.retrievers.document_compressors.chain_filter import (
LLMChainFilter,
)
from langchain.retrievers.document_compressors.cohere_rerank import CohereRerank
from langchain.retrievers.document_compressors.cross_encoder_rerank import (
CrossEncoderReranker,
)
from langchain.retrievers.document_compressors.embeddings_filter import (
EmbeddingsFilter,
)
@@ -17,5 +20,6 @@ __all__ = [
"LLMChainExtractor",
"LLMChainFilter",
"CohereRerank",
"CrossEncoderReranker",
"FlashrankRerank",
]

View File

@@ -0,0 +1,47 @@
from __future__ import annotations
import operator
from typing import Optional, Sequence
from langchain_community.cross_encoders import BaseCrossEncoder
from langchain_core.callbacks import Callbacks
from langchain_core.documents import BaseDocumentCompressor, Document
from langchain_core.pydantic_v1 import Extra
class CrossEncoderReranker(BaseDocumentCompressor):
    """Document compressor that reranks documents with a cross-encoder model.

    Scores each (query, document) pair with the configured cross encoder and
    keeps only the highest-scoring documents, best first.
    """

    model: BaseCrossEncoder
    """CrossEncoder model to use for scoring similarity
      between the query and documents."""
    top_n: int = 3
    """Number of documents to return."""

    class Config:
        """Configuration for this pydantic object."""

        extra = Extra.forbid
        arbitrary_types_allowed = True

    def compress_documents(
        self,
        documents: Sequence[Document],
        query: str,
        callbacks: Optional[Callbacks] = None,
    ) -> Sequence[Document]:
        """
        Rerank documents using CrossEncoder.

        Args:
            documents: A sequence of documents to compress.
            query: The query to use for compressing the documents.
            callbacks: Callbacks to run during the compression process.

        Returns:
            A sequence of compressed documents.
        """
        # Score every document against the query in a single batched call.
        pairs = [(query, document.page_content) for document in documents]
        similarity_scores = self.model.score(pairs)

        # Order documents by descending score and keep the top_n best.
        ranked = sorted(
            zip(documents, similarity_scores),
            key=lambda pair: pair[1],
            reverse=True,
        )
        return [document for document, _ in ranked[: self.top_n]]

View File

@@ -0,0 +1,34 @@
"""Integration test for CrossEncoderReranker."""
from typing import List
from langchain_community.cross_encoders import FakeCrossEncoder
from langchain_core.documents import Document
from langchain.retrievers.document_compressors import CrossEncoderReranker
def test_rerank() -> None:
    """Reranking returns the top_n best-scoring documents, best match first.

    FakeCrossEncoder is expected to score the "bbb*" texts highest against the
    "bbb2" query, so with the default top_n=3 only those three should survive.
    """
    texts = [
        "aaa1",
        "bbb1",
        "aaa2",
        "bbb2",
        "aaa3",
        "bbb3",
    ]
    # Comprehensions instead of list(map(lambda ...)) — clearer and idiomatic.
    docs = [Document(page_content=text) for text in texts]
    compressor = CrossEncoderReranker(model=FakeCrossEncoder())
    actual_docs = compressor.compress_documents(docs, "bbb2")
    actual = [doc.page_content for doc in actual_docs]
    expected_returned = ["bbb2", "bbb1", "bbb3"]
    expected_not_returned = ["aaa1", "aaa2", "aaa3"]
    # Generator expressions: all() can short-circuit without building a list.
    assert all(text in actual for text in expected_returned)
    assert all(text not in actual for text in expected_not_returned)
    # The exact query match must rank first.
    assert actual[0] == "bbb2"
def test_rerank_empty() -> None:
    """Compressing an empty document sequence yields an empty result."""
    docs: List[Document] = []
    compressor = CrossEncoderReranker(model=FakeCrossEncoder())
    reranked = compressor.compress_documents(docs, "query")
    assert len(reranked) == 0