mirror of https://github.com/hwchase17/langchain.git, synced 2025-09-07 14:03:26 +00:00
langchain[minor], community[minor]: add CrossEncoderReranker with HuggingFaceCrossEncoder and SagemakerEndpointCrossEncoder (#13687)
- **Description:** Support reranking based on cross encoder models available from Hugging Face.
  - Added `CrossEncoder` schema
  - Implemented `HuggingFaceCrossEncoder` and `SagemakerEndpointCrossEncoder`
  - Implemented `CrossEncoderReranker`, which provides functionality similar to `CohereRerank`
  - Added `cross-encoder-reranker.ipynb` to demonstrate how to use it. Please let me know if anything else needs to be done to make it visible on the table-of-contents navigation bar on the left, or on the card list on the [retrievers documentation page](https://python.langchain.com/docs/integrations/retrievers).
- **Issue:** N/A
- **Dependencies:** None other than the existing ones.

---------

Co-authored-by: Kenny Choe <kchoe@amazon.com>
Co-authored-by: Bagatur <baskaryan@gmail.com>
@@ -0,0 +1,273 @@
{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "fc0db1bc",
   "metadata": {},
   "source": [
    "# Cross Encoder Reranker\n",
    "\n",
    "This notebook shows how to implement a reranker in a retriever with your own cross encoder from [Hugging Face cross encoder models](https://huggingface.co/cross-encoder) or Hugging Face models that implement the cross encoder function ([example: BAAI/bge-reranker-base](https://huggingface.co/BAAI/bge-reranker-base)). `SagemakerEndpointCrossEncoder` enables you to use these Hugging Face models loaded on SageMaker.\n",
    "\n",
    "This builds on top of ideas in the [ContextualCompressionRetriever](/docs/modules/data_connection/retrievers/contextual_compression/). The overall structure of this document follows the [Cohere Reranker documentation](/docs/integrations/retrievers/cohere-reranker.ipynb).\n",
    "\n",
    "For more on why a cross encoder can be used as a reranking mechanism in conjunction with embeddings for better retrieval, refer to the [Hugging Face Cross-Encoders documentation](https://www.sbert.net/examples/applications/cross-encoder/README.html)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b37bd138-4f3c-4d2c-bc4b-be705ce27a09",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "#!pip install faiss sentence_transformers\n",
    "\n",
    "# OR (depending on Python version)\n",
    "\n",
    "#!pip install faiss-cpu sentence_transformers"
   ]
  },
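  {
   "cell_type": "markdown",
   "id": "0f3b2a1c",
   "metadata": {},
   "source": [
    "Before wiring a cross encoder into a retriever, it can help to see what one actually computes: a single relevance score per (query, passage) pair, scored jointly rather than via two separate embeddings. Below is a minimal sketch using `sentence_transformers` directly (installed by the cell above); the model choice and the example pairs are illustrative only."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1a2b3c4d",
   "metadata": {},
   "outputs": [],
   "source": [
    "# A minimal sketch, independent of LangChain: score (query, passage)\n",
    "# pairs jointly with a cross encoder. A higher score means a more\n",
    "# relevant pair.\n",
    "from sentence_transformers import CrossEncoder\n",
    "\n",
    "cross_encoder = CrossEncoder(\"BAAI/bge-reranker-base\")\n",
    "scores = cross_encoder.predict(\n",
    "    [\n",
    "        (\"What is the plan for the economy?\", \"My plan will lower your costs.\"),\n",
    "        (\"What is the plan for the economy?\", \"We stand with the people of Ukraine.\"),\n",
    "    ]\n",
    ")\n",
    "print(scores)"
   ]
  },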
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "28e8dc12",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Helper function for printing docs\n",
    "\n",
    "\n",
    "def pretty_print_docs(docs):\n",
    "    print(\n",
    "        f\"\\n{'-' * 100}\\n\".join(\n",
    "            [f\"Document {i+1}:\\n\\n\" + d.page_content for i, d in enumerate(docs)]\n",
    "        )\n",
    "    )"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "6fa3d916",
   "metadata": {
    "tags": []
   },
   "source": [
    "## Set up the base vector store retriever\n",
    "Let's start by initializing a simple vector store retriever and storing the State of the Union speech (in chunks). We can set up the retriever to retrieve a high number (20) of docs."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9fbcc58f",
   "metadata": {},
   "outputs": [],
   "source": [
    "from langchain_community.document_loaders import TextLoader\n",
    "from langchain_community.embeddings import HuggingFaceEmbeddings\n",
    "from langchain_community.vectorstores import FAISS\n",
    "from langchain_text_splitters import RecursiveCharacterTextSplitter\n",
    "\n",
    "documents = TextLoader(\"../../modules/state_of_the_union.txt\").load()\n",
    "text_splitter = RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=100)\n",
    "texts = text_splitter.split_documents(documents)\n",
    "embeddings_model = HuggingFaceEmbeddings(\n",
    "    model_name=\"sentence-transformers/msmarco-distilbert-dot-v5\"\n",
    ")\n",
    "retriever = FAISS.from_documents(texts, embeddings_model).as_retriever(\n",
    "    search_kwargs={\"k\": 20}\n",
    ")\n",
    "\n",
    "query = \"What is the plan for the economy?\"\n",
    "docs = retriever.get_relevant_documents(query)\n",
    "pretty_print_docs(docs)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b7648612",
   "metadata": {},
   "source": [
    "## Doing reranking with CrossEncoderReranker\n",
    "Now let's wrap our base retriever with a `ContextualCompressionRetriever`. `CrossEncoderReranker` uses `HuggingFaceCrossEncoder` to rerank the returned results."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 31,
   "id": "9a658023",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Document 1:\n",
      "\n",
      "More infrastructure and innovation in America. \n",
      "\n",
      "More goods moving faster and cheaper in America. \n",
      "\n",
      "More jobs where you can earn a good living in America. \n",
      "\n",
      "And instead of relying on foreign supply chains, let’s make it in America. \n",
      "\n",
      "Economists call it “increasing the productive capacity of our economy.” \n",
      "\n",
      "I call it building a better America. \n",
      "\n",
      "My plan to fight inflation will lower your costs and lower the deficit.\n",
      "----------------------------------------------------------------------------------------------------\n",
      "Document 2:\n",
      "\n",
      "Second – cut energy costs for families an average of $500 a year by combatting climate change. \n",
      "\n",
      "Let’s provide investments and tax credits to weatherize your homes and businesses to be energy efficient and you get a tax credit; double America’s clean energy production in solar, wind, and so much more; lower the price of electric vehicles, saving you another $80 a month because you’ll never have to pay at the gas pump again.\n",
      "----------------------------------------------------------------------------------------------------\n",
      "Document 3:\n",
      "\n",
      "Look at cars. \n",
      "\n",
      "Last year, there weren’t enough semiconductors to make all the cars that people wanted to buy. \n",
      "\n",
      "And guess what, prices of automobiles went up. \n",
      "\n",
      "So—we have a choice. \n",
      "\n",
      "One way to fight inflation is to drive down wages and make Americans poorer. \n",
      "\n",
      "I have a better plan to fight inflation. \n",
      "\n",
      "Lower your costs, not your wages. \n",
      "\n",
      "Make more cars and semiconductors in America. \n",
      "\n",
      "More infrastructure and innovation in America. \n",
      "\n",
      "More goods moving faster and cheaper in America.\n"
     ]
    }
   ],
   "source": [
    "from langchain.retrievers import ContextualCompressionRetriever\n",
    "from langchain.retrievers.document_compressors import CrossEncoderReranker\n",
    "from langchain_community.cross_encoders import HuggingFaceCrossEncoder\n",
    "\n",
    "model = HuggingFaceCrossEncoder(model_name=\"BAAI/bge-reranker-base\")\n",
    "compressor = CrossEncoderReranker(model=model, top_n=3)\n",
    "compression_retriever = ContextualCompressionRetriever(\n",
    "    base_compressor=compressor, base_retriever=retriever\n",
    ")\n",
    "\n",
    "compressed_docs = compression_retriever.get_relevant_documents(\n",
    "    \"What is the plan for the economy?\"\n",
    ")\n",
    "pretty_print_docs(compressed_docs)"
   ]
  },
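  {
   "cell_type": "markdown",
   "id": "2c3d4e5f",
   "metadata": {},
   "source": [
    "To make the step above less of a black box, here is a minimal sketch of what `CrossEncoderReranker` does with the cross encoder, assuming the `model`, `query`, and `docs` variables from the cells above: score each (query, document) pair, then keep the top 3 documents by score."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3d4e5f6a",
   "metadata": {},
   "outputs": [],
   "source": [
    "# A hand-rolled version of the reranking step, for illustration only.\n",
    "# `score` takes (text, text) pairs and returns one relevance score per pair.\n",
    "scores = model.score([(query, doc.page_content) for doc in docs])\n",
    "top_docs = [\n",
    "    doc\n",
    "    for _, doc in sorted(zip(scores, docs), key=lambda x: x[0], reverse=True)[:3]\n",
    "]\n",
    "pretty_print_docs(top_docs)"
   ]
  },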
  {
   "cell_type": "markdown",
   "id": "419a2bf3-de4b-4c4d-9a40-4336552f604c",
   "metadata": {},
   "source": [
    "## Uploading a Hugging Face model to a SageMaker endpoint\n",
    "\n",
    "Refer to [this article](https://www.philschmid.de/custom-inference-huggingface-sagemaker) for general guidelines. Here is a simple `inference.py` for creating an endpoint that works with `SagemakerEndpointCrossEncoder`.\n",
    "\n",
    "It downloads the Hugging Face model on the fly, so you do not need to keep model artifacts such as `pytorch_model.bin` in your `model.tar.gz`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e579c743-40c3-432f-9483-0982e2808f9a",
   "metadata": {},
   "source": [
    "import json\n",
    "import logging\n",
    "from typing import List\n",
    "\n",
    "import torch\n",
    "from sagemaker_inference import encoder\n",
    "from transformers import AutoModelForSequenceClassification, AutoTokenizer\n",
    "\n",
    "# JSON keys for the request ({\"pairs\": [[query, passage], ...]}) and\n",
    "# the response ({\"scores\": [...]}).\n",
    "PAIRS = \"pairs\"\n",
    "SCORES = \"scores\"\n",
    "\n",
    "\n",
    "class CrossEncoder:\n",
    "    def __init__(self) -> None:\n",
    "        self.device = (\n",
    "            torch.device(\"cuda\") if torch.cuda.is_available() else torch.device(\"cpu\")\n",
    "        )\n",
    "        logging.info(f\"Using device: {self.device}\")\n",
    "        model_name = \"BAAI/bge-reranker-base\"\n",
    "        self.tokenizer = AutoTokenizer.from_pretrained(model_name)\n",
    "        self.model = AutoModelForSequenceClassification.from_pretrained(model_name)\n",
    "        self.model = self.model.to(self.device)\n",
    "\n",
    "    def __call__(self, pairs: List[List[str]]) -> List[float]:\n",
    "        with torch.inference_mode():\n",
    "            inputs = self.tokenizer(\n",
    "                pairs,\n",
    "                padding=True,\n",
    "                truncation=True,\n",
    "                return_tensors=\"pt\",\n",
    "                max_length=512,\n",
    "            )\n",
    "            inputs = inputs.to(self.device)\n",
    "            scores = (\n",
    "                self.model(**inputs, return_dict=True)\n",
    "                .logits.view(\n",
    "                    -1,\n",
    "                )\n",
    "                .float()\n",
    "            )\n",
    "\n",
    "        return scores.detach().cpu().tolist()\n",
    "\n",
    "\n",
    "# Called once by the SageMaker Inference Toolkit to load the model.\n",
    "def model_fn(model_dir: str) -> CrossEncoder:\n",
    "    try:\n",
    "        return CrossEncoder()\n",
    "    except Exception:\n",
    "        logging.exception(f\"Failed to load model from: {model_dir}\")\n",
    "        raise\n",
    "\n",
    "\n",
    "# Called per invocation: decode the JSON payload, score the pairs,\n",
    "# and encode the scores in the requested accept type.\n",
    "def transform_fn(\n",
    "    cross_encoder: CrossEncoder, input_data: bytes, content_type: str, accept: str\n",
    ") -> bytes:\n",
    "    payload = json.loads(input_data)\n",
    "    model_output = cross_encoder(**payload)\n",
    "    output = {SCORES: model_output}\n",
    "    return encoder.encode(output, accept)"
   ],
   "outputs": []
  },
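  {
   "cell_type": "markdown",
   "id": "4e5f6a7b",
   "metadata": {},
   "source": [
    "Once the endpoint is deployed, `SagemakerEndpointCrossEncoder` can take the place of `HuggingFaceCrossEncoder` in the reranker above. The following is a sketch, assuming a deployed endpoint named `my-cross-encoder-endpoint` in `us-west-2`; both are placeholders for your own deployment."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5f6a7b8c",
   "metadata": {},
   "outputs": [],
   "source": [
    "from langchain_community.cross_encoders import SagemakerEndpointCrossEncoder\n",
    "\n",
    "# Endpoint name and region are placeholders -- point these at your deployment.\n",
    "sm_model = SagemakerEndpointCrossEncoder(\n",
    "    endpoint_name=\"my-cross-encoder-endpoint\",\n",
    "    region_name=\"us-west-2\",\n",
    ")\n",
    "compressor = CrossEncoderReranker(model=sm_model, top_n=3)\n",
    "compression_retriever = ContextualCompressionRetriever(\n",
    "    base_compressor=compressor, base_retriever=retriever\n",
    ")\n",
    "\n",
    "pretty_print_docs(\n",
    "    compression_retriever.get_relevant_documents(\"What is the plan for the economy?\")\n",
    ")"
   ]
  }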
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.1"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}