langchain[minor], community[minor]: add CrossEncoderReranker with HuggingFaceCrossEncoder and SagemakerEndpointCrossEncoder (#13687)

- **Description:** Support reranking based on cross encoder models
available from HuggingFace.
      - Added `CrossEncoder` schema
- Implemented `HuggingFaceCrossEncoder` and
`SagemakerEndpointCrossEncoder`
- Implemented `CrossEncoderReranker` that performs similar functionality
to `CohereRerank`
- Added `cross-encoder-reranker.ipynb` to demonstrate how to use it.
Please let me know if anything else needs to be done to make it visible
on the table-of-contents navigation bar on the left, or on the card list
on [retrievers documentation
page](https://python.langchain.com/docs/integrations/retrievers).
  - **Issue:** N/A
  - **Dependencies:** None other than the existing ones.

---------

Co-authored-by: Kenny Choe <kchoe@amazon.com>
Co-authored-by: Bagatur <baskaryan@gmail.com>
This commit is contained in:
Kenneth Choe 2024-03-31 15:51:31 -05:00 committed by GitHub
parent 3f7da03dd8
commit f98d7f7494
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
11 changed files with 660 additions and 0 deletions

View File

@ -0,0 +1,273 @@
{
"cells": [
{
"cell_type": "markdown",
"id": "fc0db1bc",
"metadata": {},
"source": [
"# Cross Encoder Reranker\n",
"\n",
"This notebook shows how to implement reranker in a retriever with your own cross encoder from [HuggingFace cross encoder models](https://huggingface.co/cross-encoder) or HuggingFace models that implements cross encoder function ([example: BAAI/bge-reranker-base](https://huggingface.co/BAAI/bge-reranker-base)). `SagemakerEndpointCrossEncoder` enables you to use these HuggingFace models loaded on Sagemaker.\n",
"\n",
"This builds on top of ideas in the [ContextualCompressionRetriever](/docs/modules/data_connection/retrievers/contextual_compression/). Overall structure of this document came from [Cohere Reranker documentation](/docs/integrations/retrievers/cohere-reranker.ipynb).\n",
"\n",
"For more about why cross encoder can be used as reranking mechanism in conjunction with embeddings for better retrieval, refer to [HuggingFace Cross-Encoders documentation](https://www.sbert.net/examples/applications/cross-encoder/README.html)."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "b37bd138-4f3c-4d2c-bc4b-be705ce27a09",
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"#!pip install faiss sentence_transformers\n",
"\n",
"# OR (depending on Python version)\n",
"\n",
"#!pip install faiss-cpu sentence_transformers"
]
},
{
"cell_type": "code",
"execution_count": 10,
"id": "28e8dc12",
"metadata": {},
"outputs": [],
"source": [
"# Helper function for printing docs\n",
"\n",
"\n",
"def pretty_print_docs(docs):\n",
" print(\n",
" f\"\\n{'-' * 100}\\n\".join(\n",
" [f\"Document {i+1}:\\n\\n\" + d.page_content for i, d in enumerate(docs)]\n",
" )\n",
" )"
]
},
{
"cell_type": "markdown",
"id": "6fa3d916",
"metadata": {
"tags": []
},
"source": [
"## Set up the base vector store retriever\n",
"Let's start by initializing a simple vector store retriever and storing the 2023 State of the Union speech (in chunks). We can set up the retriever to retrieve a high number (20) of docs."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "9fbcc58f",
"metadata": {},
"outputs": [],
"source": [
"from langchain.document_loaders import TextLoader\n",
"from langchain_community.embeddings import HuggingFaceEmbeddings\n",
"from langchain_community.vectorstores import FAISS\n",
"from langchain_text_splitters import RecursiveCharacterTextSplitter\n",
"\n",
"documents = TextLoader(\"../../modules/state_of_the_union.txt\").load()\n",
"text_splitter = RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=100)\n",
"texts = text_splitter.split_documents(documents)\n",
"embeddingsModel = HuggingFaceEmbeddings(\n",
" model_name=\"sentence-transformers/msmarco-distilbert-dot-v5\"\n",
")\n",
"retriever = FAISS.from_documents(texts, embeddingsModel).as_retriever(\n",
" search_kwargs={\"k\": 20}\n",
")\n",
"\n",
"query = \"What is the plan for the economy?\"\n",
"docs = retriever.get_relevant_documents(query)\n",
"pretty_print_docs(docs)"
]
},
{
"cell_type": "markdown",
"id": "b7648612",
"metadata": {},
"source": [
"## Doing reranking with CrossEncoderReranker\n",
"Now let's wrap our base retriever with a `ContextualCompressionRetriever`. `CrossEncoderReranker` uses `HuggingFaceCrossEncoder` to rerank the returned results."
]
},
{
"cell_type": "code",
"execution_count": 31,
"id": "9a658023",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Document 1:\n",
"\n",
"More infrastructure and innovation in America. \n",
"\n",
"More goods moving faster and cheaper in America. \n",
"\n",
"More jobs where you can earn a good living in America. \n",
"\n",
"And instead of relying on foreign supply chains, lets make it in America. \n",
"\n",
"Economists call it “increasing the productive capacity of our economy.” \n",
"\n",
"I call it building a better America. \n",
"\n",
"My plan to fight inflation will lower your costs and lower the deficit.\n",
"----------------------------------------------------------------------------------------------------\n",
"Document 2:\n",
"\n",
"Second cut energy costs for families an average of $500 a year by combatting climate change. \n",
"\n",
"Lets provide investments and tax credits to weatherize your homes and businesses to be energy efficient and you get a tax credit; double Americas clean energy production in solar, wind, and so much more; lower the price of electric vehicles, saving you another $80 a month because youll never have to pay at the gas pump again.\n",
"----------------------------------------------------------------------------------------------------\n",
"Document 3:\n",
"\n",
"Look at cars. \n",
"\n",
"Last year, there werent enough semiconductors to make all the cars that people wanted to buy. \n",
"\n",
"And guess what, prices of automobiles went up. \n",
"\n",
"So—we have a choice. \n",
"\n",
"One way to fight inflation is to drive down wages and make Americans poorer. \n",
"\n",
"I have a better plan to fight inflation. \n",
"\n",
"Lower your costs, not your wages. \n",
"\n",
"Make more cars and semiconductors in America. \n",
"\n",
"More infrastructure and innovation in America. \n",
"\n",
"More goods moving faster and cheaper in America.\n"
]
}
],
"source": [
"from langchain.retrievers import ContextualCompressionRetriever\n",
"from langchain.retrievers.document_compressors import CrossEncoderReranker\n",
"from langchain_community.cross_encoders import HuggingFaceCrossEncoder\n",
"\n",
"model = HuggingFaceCrossEncoder(model_name=\"BAAI/bge-reranker-base\")\n",
"compressor = CrossEncoderReranker(model=model, top_n=3)\n",
"compression_retriever = ContextualCompressionRetriever(\n",
" base_compressor=compressor, base_retriever=retriever\n",
")\n",
"\n",
"compressed_docs = compression_retriever.get_relevant_documents(\n",
" \"What is the plan for the economy?\"\n",
")\n",
"pretty_print_docs(compressed_docs)"
]
},
{
"cell_type": "markdown",
"id": "419a2bf3-de4b-4c4d-9a40-4336552f604c",
"metadata": {},
"source": [
"## Uploading HuggingFace model to SageMaker endpoint\n",
"\n",
"Refer to [this article](https://www.philschmid.de/custom-inference-huggingface-sagemaker) for general guideline. Here is a simple `inference.py` for creating an endpoint that works with `SagemakerEndpointCrossEncoder`.\n",
"\n",
"It downloads HuggingFace model on the fly, so you do not need to keep the model artifacts such as `pytorch_model.bin` in your `model.tar.gz`."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "e579c743-40c3-432f-9483-0982e2808f9a",
"metadata": {},
"outputs": [],
"source": [
"import json\n",
"import logging\n",
"from typing import List\n",
"\n",
"import torch\n",
"from sagemaker_inference import encoder\n",
"from transformers import AutoModelForSequenceClassification, AutoTokenizer\n",
"\n",
"PAIRS = \"pairs\"\n",
"SCORES = \"scores\"\n",
"\n",
"\n",
"class CrossEncoder:\n",
" def __init__(self) -> None:\n",
" self.device = (\n",
" torch.device(\"cuda\") if torch.cuda.is_available() else torch.device(\"cpu\")\n",
" )\n",
" logging.info(f\"Using device: {self.device}\")\n",
" model_name = \"BAAI/bge-reranker-base\"\n",
" self.tokenizer = AutoTokenizer.from_pretrained(model_name)\n",
" self.model = AutoModelForSequenceClassification.from_pretrained(model_name)\n",
" self.model = self.model.to(self.device)\n",
"\n",
" def __call__(self, pairs: List[List[str]]) -> List[float]:\n",
" with torch.inference_mode():\n",
" inputs = self.tokenizer(\n",
" pairs,\n",
" padding=True,\n",
" truncation=True,\n",
" return_tensors=\"pt\",\n",
" max_length=512,\n",
" )\n",
" inputs = inputs.to(self.device)\n",
" scores = (\n",
" self.model(**inputs, return_dict=True)\n",
" .logits.view(\n",
" -1,\n",
" )\n",
" .float()\n",
" )\n",
"\n",
" return scores.detach().cpu().tolist()\n",
"\n",
"\n",
"def model_fn(model_dir: str) -> CrossEncoder:\n",
" try:\n",
" return CrossEncoder()\n",
" except Exception:\n",
" logging.exception(f\"Failed to load model from: {model_dir}\")\n",
" raise\n",
"\n",
"\n",
"def transform_fn(\n",
" cross_encoder: CrossEncoder, input_data: bytes, content_type: str, accept: str\n",
") -> bytes:\n",
" payload = json.loads(input_data)\n",
" model_output = cross_encoder(**payload)\n",
" output = {SCORES: model_output}\n",
" return encoder.encode(output, accept)"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.1"
}
},
"nbformat": 4,
"nbformat_minor": 5
}

View File

@ -0,0 +1,30 @@
"""**Cross encoders** are wrappers around cross encoder models from different APIs and
services.
**Cross encoder models** can be LLMs or not.
**Class hierarchy:**
.. code-block::
BaseCrossEncoder --> <name>CrossEncoder # Examples: SagemakerEndpointCrossEncoder
"""
import logging
from langchain_community.cross_encoders.base import BaseCrossEncoder
from langchain_community.cross_encoders.fake import FakeCrossEncoder
from langchain_community.cross_encoders.huggingface import HuggingFaceCrossEncoder
from langchain_community.cross_encoders.sagemaker_endpoint import (
SagemakerEndpointCrossEncoder,
)
logger = logging.getLogger(__name__)
__all__ = [
"BaseCrossEncoder",
"FakeCrossEncoder",
"HuggingFaceCrossEncoder",
"SagemakerEndpointCrossEncoder",
]

View File

@ -0,0 +1,17 @@
from abc import ABC, abstractmethod
from typing import List, Tuple
class BaseCrossEncoder(ABC):
"""Interface for cross encoder models."""
@abstractmethod
def score(self, text_pairs: List[Tuple[str, str]]) -> List[float]:
"""Score pairs' similarity.
Args:
text_pairs: List of pairs of texts.
Returns:
List of scores.
"""

View File

@ -0,0 +1,18 @@
from difflib import SequenceMatcher
from typing import List, Tuple
from langchain_core.pydantic_v1 import BaseModel
from langchain_community.cross_encoders.base import BaseCrossEncoder
class FakeCrossEncoder(BaseCrossEncoder, BaseModel):
"""Fake cross encoder model."""
def score(self, text_pairs: List[Tuple[str, str]]) -> List[float]:
scores = list(
map(
lambda pair: SequenceMatcher(None, pair[0], pair[1]).ratio(), text_pairs
)
)
return scores

View File

@ -0,0 +1,63 @@
from typing import Any, Dict, List, Tuple
from langchain_core.pydantic_v1 import BaseModel, Extra, Field
from langchain_community.cross_encoders.base import BaseCrossEncoder
DEFAULT_MODEL_NAME = "BAAI/bge-reranker-base"
class HuggingFaceCrossEncoder(BaseModel, BaseCrossEncoder):
"""HuggingFace cross encoder models.
Example:
.. code-block:: python
from langchain_community.cross_encoders import HuggingFaceCrossEncoder
model_name = "BAAI/bge-reranker-base"
model_kwargs = {'device': 'cpu'}
hf = HuggingFaceCrossEncoder(
model_name=model_name,
model_kwargs=model_kwargs
)
"""
client: Any #: :meta private:
model_name: str = DEFAULT_MODEL_NAME
"""Model name to use."""
model_kwargs: Dict[str, Any] = Field(default_factory=dict)
"""Keyword arguments to pass to the model."""
def __init__(self, **kwargs: Any):
"""Initialize the sentence_transformer."""
super().__init__(**kwargs)
try:
import sentence_transformers
except ImportError as exc:
raise ImportError(
"Could not import sentence_transformers python package. "
"Please install it with `pip install sentence-transformers`."
) from exc
self.client = sentence_transformers.CrossEncoder(
self.model_name, **self.model_kwargs
)
class Config:
"""Configuration for this pydantic object."""
extra = Extra.forbid
def score(self, text_pairs: List[Tuple[str, str]]) -> List[float]:
"""Compute similarity scores using a HuggingFace transformer model.
Args:
text_pairs: The list of text text_pairs to score the similarity.
Returns:
List of scores, one for each pair.
"""
scores = self.client.predict(text_pairs)
return scores

View File

@ -0,0 +1,151 @@
import json
from typing import Any, Dict, List, Optional, Tuple
from langchain_core.pydantic_v1 import BaseModel, Extra, root_validator
from langchain_community.cross_encoders.base import BaseCrossEncoder
class CrossEncoderContentHandler:
"""Content handler for CrossEncoder class."""
content_type = "application/json"
accepts = "application/json"
def transform_input(self, text_pairs: List[Tuple[str, str]]) -> bytes:
input_str = json.dumps({"text_pairs": text_pairs})
return input_str.encode("utf-8")
def transform_output(self, output: Any) -> List[float]:
response_json = json.loads(output.read().decode("utf-8"))
scores = response_json["scores"]
return scores
class SagemakerEndpointCrossEncoder(BaseModel, BaseCrossEncoder):
"""SageMaker Inference CrossEncoder endpoint.
To use, you must supply the endpoint name from your deployed
Sagemaker model & the region where it is deployed.
To authenticate, the AWS client uses the following methods to
automatically load credentials:
https://boto3.amazonaws.com/v1/documentation/api/latest/guide/credentials.html
If a specific credential profile should be used, you must pass
the name of the profile from the ~/.aws/credentials file that is to be used.
Make sure the credentials / roles used have the required policies to
access the Sagemaker endpoint.
See: https://docs.aws.amazon.com/IAM/latest/UserGuide/access_policies.html
"""
"""
Example:
.. code-block:: python
from langchain.embeddings import SagemakerEndpointCrossEncoder
endpoint_name = (
"my-endpoint-name"
)
region_name = (
"us-west-2"
)
credentials_profile_name = (
"default"
)
se = SagemakerEndpointCrossEncoder(
endpoint_name=endpoint_name,
region_name=region_name,
credentials_profile_name=credentials_profile_name
)
"""
client: Any #: :meta private:
endpoint_name: str = ""
"""The name of the endpoint from the deployed Sagemaker model.
Must be unique within an AWS Region."""
region_name: str = ""
"""The aws region where the Sagemaker model is deployed, eg. `us-west-2`."""
credentials_profile_name: Optional[str] = None
"""The name of the profile in the ~/.aws/credentials or ~/.aws/config files, which
has either access keys or role information specified.
If not specified, the default credential profile or, if on an EC2 instance,
credentials from IMDS will be used.
See: https://boto3.amazonaws.com/v1/documentation/api/latest/guide/credentials.html
"""
content_handler: CrossEncoderContentHandler = CrossEncoderContentHandler()
model_kwargs: Optional[Dict] = None
"""Keyword arguments to pass to the model."""
endpoint_kwargs: Optional[Dict] = None
"""Optional attributes passed to the invoke_endpoint
function. See `boto3`_. docs for more info.
.. _boto3: <https://boto3.amazonaws.com/v1/documentation/api/latest/index.html>
"""
class Config:
"""Configuration for this pydantic object."""
extra = Extra.forbid
arbitrary_types_allowed = True
@root_validator()
def validate_environment(cls, values: Dict) -> Dict:
"""Validate that AWS credentials to and python package exists in environment."""
try:
import boto3
try:
if values["credentials_profile_name"] is not None:
session = boto3.Session(
profile_name=values["credentials_profile_name"]
)
else:
# use default credentials
session = boto3.Session()
values["client"] = session.client(
"sagemaker-runtime", region_name=values["region_name"]
)
except Exception as e:
raise ValueError(
"Could not load credentials to authenticate with AWS client. "
"Please check that credentials in the specified "
"profile name are valid."
) from e
except ImportError:
raise ImportError(
"Could not import boto3 python package. "
"Please install it with `pip install boto3`."
)
return values
def score(self, text_pairs: List[Tuple[str, str]]) -> List[float]:
"""Call out to SageMaker Inference CrossEncoder endpoint."""
_endpoint_kwargs = self.endpoint_kwargs or {}
body = self.content_handler.transform_input(text_pairs)
content_type = self.content_handler.content_type
accepts = self.content_handler.accepts
# send request
try:
response = self.client.invoke_endpoint(
EndpointName=self.endpoint_name,
Body=body,
ContentType=content_type,
Accept=accepts,
**_endpoint_kwargs,
)
except Exception as e:
raise ValueError(f"Error raised by inference endpoint: {e}")
return self.content_handler.transform_output(response["Body"])

View File

@ -0,0 +1 @@
"""Test cross encoder integrations."""

View File

@ -0,0 +1,22 @@
"""Test huggingface cross encoders."""
from langchain_community.cross_encoders import HuggingFaceCrossEncoder
def _assert(encoder: HuggingFaceCrossEncoder) -> None:
query = "I love you"
texts = ["I love you", "I like you", "I don't like you", "I hate you"]
output = encoder.score([(query, text) for text in texts])
for i in range(len(texts) - 1):
assert output[i] > output[i + 1]
def test_huggingface_cross_encoder() -> None:
encoder = HuggingFaceCrossEncoder()
_assert(encoder)
def test_huggingface_cross_encoder_with_designated_model_name() -> None:
encoder = HuggingFaceCrossEncoder(model_name="cross-encoder/ms-marco-MiniLM-L-6-v2")
_assert(encoder)

View File

@ -6,6 +6,9 @@ from langchain.retrievers.document_compressors.chain_filter import (
LLMChainFilter,
)
from langchain.retrievers.document_compressors.cohere_rerank import CohereRerank
from langchain.retrievers.document_compressors.cross_encoder_rerank import (
CrossEncoderReranker,
)
from langchain.retrievers.document_compressors.embeddings_filter import (
EmbeddingsFilter,
)
@ -17,5 +20,6 @@ __all__ = [
"LLMChainExtractor",
"LLMChainFilter",
"CohereRerank",
"CrossEncoderReranker",
"FlashrankRerank",
]

View File

@ -0,0 +1,47 @@
from __future__ import annotations
import operator
from typing import Optional, Sequence
from langchain_community.cross_encoders import BaseCrossEncoder
from langchain_core.callbacks import Callbacks
from langchain_core.documents import BaseDocumentCompressor, Document
from langchain_core.pydantic_v1 import Extra
class CrossEncoderReranker(BaseDocumentCompressor):
"""Document compressor that uses CrossEncoder for reranking."""
model: BaseCrossEncoder
"""CrossEncoder model to use for scoring similarity
between the query and documents."""
top_n: int = 3
"""Number of documents to return."""
class Config:
"""Configuration for this pydantic object."""
extra = Extra.forbid
arbitrary_types_allowed = True
def compress_documents(
self,
documents: Sequence[Document],
query: str,
callbacks: Optional[Callbacks] = None,
) -> Sequence[Document]:
"""
Rerank documents using CrossEncoder.
Args:
documents: A sequence of documents to compress.
query: The query to use for compressing the documents.
callbacks: Callbacks to run during the compression process.
Returns:
A sequence of compressed documents.
"""
scores = self.model.score([(query, doc.page_content) for doc in documents])
docs_with_scores = list(zip(documents, scores))
result = sorted(docs_with_scores, key=operator.itemgetter(1), reverse=True)
return [doc for doc, _ in result[: self.top_n]]

View File

@ -0,0 +1,34 @@
"""Integration test for CrossEncoderReranker."""
from typing import List
from langchain_community.cross_encoders import FakeCrossEncoder
from langchain_core.documents import Document
from langchain.retrievers.document_compressors import CrossEncoderReranker
def test_rerank() -> None:
texts = [
"aaa1",
"bbb1",
"aaa2",
"bbb2",
"aaa3",
"bbb3",
]
docs = list(map(lambda text: Document(page_content=text), texts))
compressor = CrossEncoderReranker(model=FakeCrossEncoder())
actual_docs = compressor.compress_documents(docs, "bbb2")
actual = list(map(lambda doc: doc.page_content, actual_docs))
expected_returned = ["bbb2", "bbb1", "bbb3"]
expected_not_returned = ["aaa1", "aaa2", "aaa3"]
assert all([text in actual for text in expected_returned])
assert all([text not in actual for text in expected_not_returned])
assert actual[0] == "bbb2"
def test_rerank_empty() -> None:
docs: List[Document] = []
compressor = CrossEncoderReranker(model=FakeCrossEncoder())
actual_docs = compressor.compress_documents(docs, "query")
assert len(actual_docs) == 0