Mirror of https://github.com/hwchase17/langchain.git, synced 2025-08-10 13:27:36 +00:00
langchain[minor], community[minor]: add CrossEncoderReranker with HuggingFaceCrossEncoder and SagemakerEndpointCrossEncoder (#13687)
- **Description:** Support reranking based on cross encoder models available from HuggingFace.
  - Added `CrossEncoder` schema
  - Implemented `HuggingFaceCrossEncoder` and `SagemakerEndpointCrossEncoder`
  - Implemented `CrossEncoderReranker` that performs similar functionality to `CohereRerank`
  - Added `cross-encoder-reranker.ipynb` to demonstrate how to use it. Please let me know if anything else needs to be done to make it visible on the table-of-contents navigation bar on the left, or on the card list on the [retrievers documentation page](https://python.langchain.com/docs/integrations/retrievers).
- **Issue:** N/A
- **Dependencies:** None other than the existing ones.

---------

Co-authored-by: Kenny Choe <kchoe@amazon.com>
Co-authored-by: Bagatur <baskaryan@gmail.com>
This commit is contained in: parent 3f7da03dd8, commit f98d7f7494
cross-encoder-reranker.ipynb (new file)
@@ -0,0 +1,273 @@
{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "fc0db1bc",
   "metadata": {},
   "source": [
    "# Cross Encoder Reranker\n",
    "\n",
    "This notebook shows how to implement a reranker in a retriever with your own cross encoder, using [HuggingFace cross encoder models](https://huggingface.co/cross-encoder) or HuggingFace models that implement the cross encoder function (example: [BAAI/bge-reranker-base](https://huggingface.co/BAAI/bge-reranker-base)). `SagemakerEndpointCrossEncoder` enables you to use these HuggingFace models loaded on SageMaker.\n",
    "\n",
    "This builds on top of ideas in the [ContextualCompressionRetriever](/docs/modules/data_connection/retrievers/contextual_compression/). The overall structure of this document is adapted from the [Cohere Reranker documentation](/docs/integrations/retrievers/cohere-reranker.ipynb).\n",
    "\n",
    "For more on why cross encoders can be used as a reranking mechanism in conjunction with embeddings for better retrieval, refer to the [HuggingFace Cross-Encoders documentation](https://www.sbert.net/examples/applications/cross-encoder/README.html)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b37bd138-4f3c-4d2c-bc4b-be705ce27a09",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "#!pip install faiss sentence_transformers\n",
    "\n",
    "# OR (depending on Python version)\n",
    "\n",
    "#!pip install faiss-cpu sentence_transformers"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "28e8dc12",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Helper function for printing docs\n",
    "\n",
    "\n",
    "def pretty_print_docs(docs):\n",
    "    print(\n",
    "        f\"\\n{'-' * 100}\\n\".join(\n",
    "            [f\"Document {i+1}:\\n\\n\" + d.page_content for i, d in enumerate(docs)]\n",
    "        )\n",
    "    )"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "6fa3d916",
   "metadata": {
    "tags": []
   },
   "source": [
    "## Set up the base vector store retriever\n",
    "Let's start by initializing a simple vector store retriever and storing the 2023 State of the Union speech (in chunks). We can set up the retriever to retrieve a high number (20) of docs."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9fbcc58f",
   "metadata": {},
   "outputs": [],
   "source": [
    "from langchain.document_loaders import TextLoader\n",
    "from langchain_community.embeddings import HuggingFaceEmbeddings\n",
    "from langchain_community.vectorstores import FAISS\n",
    "from langchain_text_splitters import RecursiveCharacterTextSplitter\n",
    "\n",
    "documents = TextLoader(\"../../modules/state_of_the_union.txt\").load()\n",
    "text_splitter = RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=100)\n",
    "texts = text_splitter.split_documents(documents)\n",
    "embeddingsModel = HuggingFaceEmbeddings(\n",
    "    model_name=\"sentence-transformers/msmarco-distilbert-dot-v5\"\n",
    ")\n",
    "retriever = FAISS.from_documents(texts, embeddingsModel).as_retriever(\n",
    "    search_kwargs={\"k\": 20}\n",
    ")\n",
    "\n",
    "query = \"What is the plan for the economy?\"\n",
    "docs = retriever.get_relevant_documents(query)\n",
    "pretty_print_docs(docs)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b7648612",
   "metadata": {},
   "source": [
    "## Doing reranking with CrossEncoderReranker\n",
    "Now let's wrap our base retriever with a `ContextualCompressionRetriever`. `CrossEncoderReranker` uses `HuggingFaceCrossEncoder` to rerank the returned results."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 31,
   "id": "9a658023",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Document 1:\n",
      "\n",
      "More infrastructure and innovation in America. \n",
      "\n",
      "More goods moving faster and cheaper in America. \n",
      "\n",
      "More jobs where you can earn a good living in America. \n",
      "\n",
      "And instead of relying on foreign supply chains, let’s make it in America. \n",
      "\n",
      "Economists call it “increasing the productive capacity of our economy.” \n",
      "\n",
      "I call it building a better America. \n",
      "\n",
      "My plan to fight inflation will lower your costs and lower the deficit.\n",
      "----------------------------------------------------------------------------------------------------\n",
      "Document 2:\n",
      "\n",
      "Second – cut energy costs for families an average of $500 a year by combatting climate change. \n",
      "\n",
      "Let’s provide investments and tax credits to weatherize your homes and businesses to be energy efficient and you get a tax credit; double America’s clean energy production in solar, wind, and so much more; lower the price of electric vehicles, saving you another $80 a month because you’ll never have to pay at the gas pump again.\n",
      "----------------------------------------------------------------------------------------------------\n",
      "Document 3:\n",
      "\n",
      "Look at cars. \n",
      "\n",
      "Last year, there weren’t enough semiconductors to make all the cars that people wanted to buy. \n",
      "\n",
      "And guess what, prices of automobiles went up. \n",
      "\n",
      "So—we have a choice. \n",
      "\n",
      "One way to fight inflation is to drive down wages and make Americans poorer. \n",
      "\n",
      "I have a better plan to fight inflation. \n",
      "\n",
      "Lower your costs, not your wages. \n",
      "\n",
      "Make more cars and semiconductors in America. \n",
      "\n",
      "More infrastructure and innovation in America. \n",
      "\n",
      "More goods moving faster and cheaper in America.\n"
     ]
    }
   ],
   "source": [
    "from langchain.retrievers import ContextualCompressionRetriever\n",
    "from langchain.retrievers.document_compressors import CrossEncoderReranker\n",
    "from langchain_community.cross_encoders import HuggingFaceCrossEncoder\n",
    "\n",
    "model = HuggingFaceCrossEncoder(model_name=\"BAAI/bge-reranker-base\")\n",
    "compressor = CrossEncoderReranker(model=model, top_n=3)\n",
    "compression_retriever = ContextualCompressionRetriever(\n",
    "    base_compressor=compressor, base_retriever=retriever\n",
    ")\n",
    "\n",
    "compressed_docs = compression_retriever.get_relevant_documents(\n",
    "    \"What is the plan for the economy?\"\n",
    ")\n",
    "pretty_print_docs(compressed_docs)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "419a2bf3-de4b-4c4d-9a40-4336552f604c",
   "metadata": {},
   "source": [
    "## Uploading HuggingFace model to SageMaker endpoint\n",
    "\n",
    "Refer to [this article](https://www.philschmid.de/custom-inference-huggingface-sagemaker) for general guidelines. Here is a simple `inference.py` for creating an endpoint that works with `SagemakerEndpointCrossEncoder`.\n",
    "\n",
    "It downloads the HuggingFace model on the fly, so you do not need to keep the model artifacts such as `pytorch_model.bin` in your `model.tar.gz`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e579c743-40c3-432f-9483-0982e2808f9a",
   "metadata": {},
   "outputs": [],
   "source": [
    "import json\n",
    "import logging\n",
    "from typing import List\n",
    "\n",
    "import torch\n",
    "from sagemaker_inference import encoder\n",
    "from transformers import AutoModelForSequenceClassification, AutoTokenizer\n",
    "\n",
    "# Keys of the request/response JSON used by CrossEncoderContentHandler\n",
    "PAIRS = \"text_pairs\"\n",
    "SCORES = \"scores\"\n",
    "\n",
    "\n",
    "class CrossEncoder:\n",
    "    def __init__(self) -> None:\n",
    "        self.device = (\n",
    "            torch.device(\"cuda\") if torch.cuda.is_available() else torch.device(\"cpu\")\n",
    "        )\n",
    "        logging.info(f\"Using device: {self.device}\")\n",
    "        model_name = \"BAAI/bge-reranker-base\"\n",
    "        self.tokenizer = AutoTokenizer.from_pretrained(model_name)\n",
    "        self.model = AutoModelForSequenceClassification.from_pretrained(model_name)\n",
    "        self.model = self.model.to(self.device)\n",
    "\n",
    "    # The parameter name must match the \"text_pairs\" key sent by\n",
    "    # CrossEncoderContentHandler, since transform_fn unpacks the payload\n",
    "    # as keyword arguments.\n",
    "    def __call__(self, text_pairs: List[List[str]]) -> List[float]:\n",
    "        with torch.inference_mode():\n",
    "            inputs = self.tokenizer(\n",
    "                text_pairs,\n",
    "                padding=True,\n",
    "                truncation=True,\n",
    "                return_tensors=\"pt\",\n",
    "                max_length=512,\n",
    "            )\n",
    "            inputs = inputs.to(self.device)\n",
    "            scores = (\n",
    "                self.model(**inputs, return_dict=True)\n",
    "                .logits.view(\n",
    "                    -1,\n",
    "                )\n",
    "                .float()\n",
    "            )\n",
    "\n",
    "        return scores.detach().cpu().tolist()\n",
    "\n",
    "\n",
    "def model_fn(model_dir: str) -> CrossEncoder:\n",
    "    try:\n",
    "        return CrossEncoder()\n",
    "    except Exception:\n",
    "        logging.exception(f\"Failed to load model from: {model_dir}\")\n",
    "        raise\n",
    "\n",
    "\n",
    "def transform_fn(\n",
    "    cross_encoder: CrossEncoder, input_data: bytes, content_type: str, accept: str\n",
    ") -> bytes:\n",
    "    payload = json.loads(input_data)\n",
    "    model_output = cross_encoder(**payload)\n",
    "    output = {SCORES: model_output}\n",
    "    return encoder.encode(output, accept)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.1"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
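The article linked in the notebook packages `inference.py` into a `model.tar.gz` (under a `code/` directory) and deploys it with the SageMaker Python SDK. A minimal deployment sketch assuming that layout; the S3 path, IAM role, framework versions, and instance type are illustrative assumptions, not part of this PR:

from sagemaker.huggingface.model import HuggingFaceModel

# model.tar.gz contains only code/inference.py; the model weights are
# downloaded on the fly by CrossEncoder.__init__ at endpoint startup.
huggingface_model = HuggingFaceModel(
    model_data="s3://my-bucket/bge-reranker-base/model.tar.gz",  # hypothetical path
    role="arn:aws:iam::123456789012:role/MySageMakerRole",  # hypothetical role
    transformers_version="4.26",
    pytorch_version="1.13",
    py_version="py39",
)
predictor = huggingface_model.deploy(
    initial_instance_count=1,
    instance_type="ml.m5.xlarge",
)
print(predictor.endpoint_name)  # pass this to SagemakerEndpointCrossEncoder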
libs/community/langchain_community/cross_encoders/__init__.py (new file)
@@ -0,0 +1,30 @@
"""**Cross encoders** are wrappers around cross encoder models from different APIs and
services.

**Cross encoder models** can be LLMs or not.

**Class hierarchy:**

.. code-block::

    BaseCrossEncoder --> <name>CrossEncoder  # Examples: SagemakerEndpointCrossEncoder
"""

import logging

from langchain_community.cross_encoders.base import BaseCrossEncoder
from langchain_community.cross_encoders.fake import FakeCrossEncoder
from langchain_community.cross_encoders.huggingface import HuggingFaceCrossEncoder
from langchain_community.cross_encoders.sagemaker_endpoint import (
    SagemakerEndpointCrossEncoder,
)

logger = logging.getLogger(__name__)

__all__ = [
    "BaseCrossEncoder",
    "FakeCrossEncoder",
    "HuggingFaceCrossEncoder",
    "SagemakerEndpointCrossEncoder",
]
libs/community/langchain_community/cross_encoders/base.py (new file)
@@ -0,0 +1,17 @@
from abc import ABC, abstractmethod
from typing import List, Tuple


class BaseCrossEncoder(ABC):
    """Interface for cross encoder models."""

    @abstractmethod
    def score(self, text_pairs: List[Tuple[str, str]]) -> List[float]:
        """Score pairs' similarity.

        Args:
            text_pairs: List of pairs of texts.

        Returns:
            List of scores.
        """
libs/community/langchain_community/cross_encoders/fake.py (new file)
@@ -0,0 +1,18 @@
from difflib import SequenceMatcher
from typing import List, Tuple

from langchain_core.pydantic_v1 import BaseModel

from langchain_community.cross_encoders.base import BaseCrossEncoder


class FakeCrossEncoder(BaseCrossEncoder, BaseModel):
    """Fake cross encoder model."""

    def score(self, text_pairs: List[Tuple[str, str]]) -> List[float]:
        scores = list(
            map(
                lambda pair: SequenceMatcher(None, pair[0], pair[1]).ratio(), text_pairs
            )
        )
        return scores
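`FakeCrossEncoder` scores pairs by plain string similarity, which makes deterministic tests easy. For instance (values follow from `SequenceMatcher`):

from langchain_community.cross_encoders import FakeCrossEncoder

encoder = FakeCrossEncoder()
# Identical strings score 1.0; scores fall as the strings diverge.
print(encoder.score([("bbb2", "bbb2"), ("bbb2", "bbb1"), ("bbb2", "aaa1")]))
# [1.0, 0.75, 0.0]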
libs/community/langchain_community/cross_encoders/huggingface.py (new file)
@@ -0,0 +1,63 @@
from typing import Any, Dict, List, Tuple

from langchain_core.pydantic_v1 import BaseModel, Extra, Field

from langchain_community.cross_encoders.base import BaseCrossEncoder

DEFAULT_MODEL_NAME = "BAAI/bge-reranker-base"


class HuggingFaceCrossEncoder(BaseModel, BaseCrossEncoder):
    """HuggingFace cross encoder models.

    Example:
        .. code-block:: python

            from langchain_community.cross_encoders import HuggingFaceCrossEncoder

            model_name = "BAAI/bge-reranker-base"
            model_kwargs = {'device': 'cpu'}
            hf = HuggingFaceCrossEncoder(
                model_name=model_name,
                model_kwargs=model_kwargs
            )
    """

    client: Any  #: :meta private:
    model_name: str = DEFAULT_MODEL_NAME
    """Model name to use."""
    model_kwargs: Dict[str, Any] = Field(default_factory=dict)
    """Keyword arguments to pass to the model."""

    def __init__(self, **kwargs: Any):
        """Initialize the sentence_transformer."""
        super().__init__(**kwargs)
        try:
            import sentence_transformers

        except ImportError as exc:
            raise ImportError(
                "Could not import sentence_transformers python package. "
                "Please install it with `pip install sentence-transformers`."
            ) from exc

        self.client = sentence_transformers.CrossEncoder(
            self.model_name, **self.model_kwargs
        )

    class Config:
        """Configuration for this pydantic object."""

        extra = Extra.forbid

    def score(self, text_pairs: List[Tuple[str, str]]) -> List[float]:
        """Compute similarity scores using a HuggingFace transformer model.

        Args:
            text_pairs: The list of text pairs to score for similarity.

        Returns:
            List of scores, one for each pair.
        """
        scores = self.client.predict(text_pairs)
        return scores
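A quick usage sketch of the scorer itself; the query/passage pairs are illustrative and exact values depend on the model:

from langchain_community.cross_encoders import HuggingFaceCrossEncoder

model = HuggingFaceCrossEncoder(model_name="BAAI/bge-reranker-base")
scores = model.score(
    [
        ("What is the plan for the economy?", "My plan to fight inflation will lower your costs."),
        ("What is the plan for the economy?", "Look at cars."),
    ]
)
# The more relevant first pair should receive the higher score.
print(scores)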
libs/community/langchain_community/cross_encoders/sagemaker_endpoint.py (new file)
@@ -0,0 +1,151 @@
import json
from typing import Any, Dict, List, Optional, Tuple

from langchain_core.pydantic_v1 import BaseModel, Extra, root_validator

from langchain_community.cross_encoders.base import BaseCrossEncoder


class CrossEncoderContentHandler:
    """Content handler for CrossEncoder class."""

    content_type = "application/json"
    accepts = "application/json"

    def transform_input(self, text_pairs: List[Tuple[str, str]]) -> bytes:
        input_str = json.dumps({"text_pairs": text_pairs})
        return input_str.encode("utf-8")

    def transform_output(self, output: Any) -> List[float]:
        response_json = json.loads(output.read().decode("utf-8"))
        scores = response_json["scores"]
        return scores


class SagemakerEndpointCrossEncoder(BaseModel, BaseCrossEncoder):
    """SageMaker Inference CrossEncoder endpoint.

    To use, you must supply the endpoint name from your deployed
    Sagemaker model & the region where it is deployed.

    To authenticate, the AWS client uses the following methods to
    automatically load credentials:
    https://boto3.amazonaws.com/v1/documentation/api/latest/guide/credentials.html

    If a specific credential profile should be used, you must pass
    the name of the profile from the ~/.aws/credentials file that is to be used.

    Make sure the credentials / roles used have the required policies to
    access the Sagemaker endpoint.
    See: https://docs.aws.amazon.com/IAM/latest/UserGuide/access_policies.html

    Example:
        .. code-block:: python

            from langchain_community.cross_encoders import SagemakerEndpointCrossEncoder

            endpoint_name = "my-endpoint-name"
            region_name = "us-west-2"
            credentials_profile_name = "default"
            se = SagemakerEndpointCrossEncoder(
                endpoint_name=endpoint_name,
                region_name=region_name,
                credentials_profile_name=credentials_profile_name
            )
    """

    client: Any  #: :meta private:

    endpoint_name: str = ""
    """The name of the endpoint from the deployed Sagemaker model.
    Must be unique within an AWS Region."""

    region_name: str = ""
    """The aws region where the Sagemaker model is deployed, eg. `us-west-2`."""

    credentials_profile_name: Optional[str] = None
    """The name of the profile in the ~/.aws/credentials or ~/.aws/config files, which
    has either access keys or role information specified.
    If not specified, the default credential profile or, if on an EC2 instance,
    credentials from IMDS will be used.
    See: https://boto3.amazonaws.com/v1/documentation/api/latest/guide/credentials.html
    """

    content_handler: CrossEncoderContentHandler = CrossEncoderContentHandler()

    model_kwargs: Optional[Dict] = None
    """Keyword arguments to pass to the model."""

    endpoint_kwargs: Optional[Dict] = None
    """Optional attributes passed to the invoke_endpoint
    function. See `boto3`_ docs for more info.
    .. _boto3: <https://boto3.amazonaws.com/v1/documentation/api/latest/index.html>
    """

    class Config:
        """Configuration for this pydantic object."""

        extra = Extra.forbid
        arbitrary_types_allowed = True

    @root_validator()
    def validate_environment(cls, values: Dict) -> Dict:
        """Validate that AWS credentials and the python package exist in environment."""
        try:
            import boto3

            try:
                if values["credentials_profile_name"] is not None:
                    session = boto3.Session(
                        profile_name=values["credentials_profile_name"]
                    )
                else:
                    # use default credentials
                    session = boto3.Session()

                values["client"] = session.client(
                    "sagemaker-runtime", region_name=values["region_name"]
                )

            except Exception as e:
                raise ValueError(
                    "Could not load credentials to authenticate with AWS client. "
                    "Please check that credentials in the specified "
                    "profile name are valid."
                ) from e

        except ImportError:
            raise ImportError(
                "Could not import boto3 python package. "
                "Please install it with `pip install boto3`."
            )
        return values

    def score(self, text_pairs: List[Tuple[str, str]]) -> List[float]:
        """Call out to SageMaker Inference CrossEncoder endpoint."""
        _endpoint_kwargs = self.endpoint_kwargs or {}

        body = self.content_handler.transform_input(text_pairs)
        content_type = self.content_handler.content_type
        accepts = self.content_handler.accepts

        # send request
        try:
            response = self.client.invoke_endpoint(
                EndpointName=self.endpoint_name,
                Body=body,
                ContentType=content_type,
                Accept=accepts,
                **_endpoint_kwargs,
            )
        except Exception as e:
            raise ValueError(f"Error raised by inference endpoint: {e}")

        return self.content_handler.transform_output(response["Body"])
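`CrossEncoderContentHandler` fixes the wire format that the notebook's `inference.py` must speak. A hedged sketch of a round trip; the endpoint name and region are illustrative:

from langchain_community.cross_encoders import SagemakerEndpointCrossEncoder

# Assumes an endpoint serving the notebook's inference.py.
encoder = SagemakerEndpointCrossEncoder(
    endpoint_name="bge-reranker-base",  # hypothetical endpoint name
    region_name="us-west-2",
)

# transform_input sends:    {"text_pairs": [["query", "passage"], ...]}
# transform_output expects: {"scores": [...]}
scores = encoder.score([("What is the plan for the economy?", "Look at cars.")])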
(new file: cross encoder integration tests package)
@@ -0,0 +1 @@
"""Test cross encoder integrations."""
(new file: HuggingFaceCrossEncoder integration tests)
@@ -0,0 +1,22 @@
"""Test huggingface cross encoders."""

from langchain_community.cross_encoders import HuggingFaceCrossEncoder


def _assert(encoder: HuggingFaceCrossEncoder) -> None:
    query = "I love you"
    texts = ["I love you", "I like you", "I don't like you", "I hate you"]
    output = encoder.score([(query, text) for text in texts])

    for i in range(len(texts) - 1):
        assert output[i] > output[i + 1]


def test_huggingface_cross_encoder() -> None:
    encoder = HuggingFaceCrossEncoder()
    _assert(encoder)


def test_huggingface_cross_encoder_with_designated_model_name() -> None:
    encoder = HuggingFaceCrossEncoder(model_name="cross-encoder/ms-marco-MiniLM-L-6-v2")
    _assert(encoder)
libs/langchain/langchain/retrievers/document_compressors/__init__.py
@@ -6,6 +6,9 @@ from langchain.retrievers.document_compressors.chain_filter import (
     LLMChainFilter,
 )
 from langchain.retrievers.document_compressors.cohere_rerank import CohereRerank
+from langchain.retrievers.document_compressors.cross_encoder_rerank import (
+    CrossEncoderReranker,
+)
 from langchain.retrievers.document_compressors.embeddings_filter import (
     EmbeddingsFilter,
 )
@@ -17,5 +20,6 @@ __all__ = [
     "LLMChainExtractor",
     "LLMChainFilter",
     "CohereRerank",
+    "CrossEncoderReranker",
     "FlashrankRerank",
 ]
libs/langchain/langchain/retrievers/document_compressors/cross_encoder_rerank.py (new file)
@@ -0,0 +1,47 @@
from __future__ import annotations

import operator
from typing import Optional, Sequence

from langchain_community.cross_encoders import BaseCrossEncoder
from langchain_core.callbacks import Callbacks
from langchain_core.documents import BaseDocumentCompressor, Document
from langchain_core.pydantic_v1 import Extra


class CrossEncoderReranker(BaseDocumentCompressor):
    """Document compressor that uses CrossEncoder for reranking."""

    model: BaseCrossEncoder
    """CrossEncoder model to use for scoring similarity
      between the query and documents."""
    top_n: int = 3
    """Number of documents to return."""

    class Config:
        """Configuration for this pydantic object."""

        extra = Extra.forbid
        arbitrary_types_allowed = True

    def compress_documents(
        self,
        documents: Sequence[Document],
        query: str,
        callbacks: Optional[Callbacks] = None,
    ) -> Sequence[Document]:
        """
        Rerank documents using CrossEncoder.

        Args:
            documents: A sequence of documents to compress.
            query: The query to use for compressing the documents.
            callbacks: Callbacks to run during the compression process.

        Returns:
            A sequence of compressed documents.
        """
        scores = self.model.score([(query, doc.page_content) for doc in documents])
        docs_with_scores = list(zip(documents, scores))
        result = sorted(docs_with_scores, key=operator.itemgetter(1), reverse=True)
        return [doc for doc, _ in result[: self.top_n]]
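The compressor can also be used on its own, without a retriever; a small sketch with illustrative documents (the resulting order depends on the model):

from langchain_core.documents import Document

from langchain.retrievers.document_compressors import CrossEncoderReranker
from langchain_community.cross_encoders import HuggingFaceCrossEncoder

reranker = CrossEncoderReranker(
    model=HuggingFaceCrossEncoder(model_name="BAAI/bge-reranker-base"), top_n=2
)
docs = [
    Document(page_content="My plan to fight inflation will lower your costs."),
    Document(page_content="Look at cars."),
    Document(page_content="More infrastructure and innovation in America."),
]
# Scores each (query, page_content) pair, sorts descending, keeps top_n.
top = reranker.compress_documents(docs, "What is the plan for the economy?")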
(new file: CrossEncoderReranker integration tests)
@@ -0,0 +1,34 @@
"""Integration test for CrossEncoderReranker."""

from typing import List

from langchain_community.cross_encoders import FakeCrossEncoder
from langchain_core.documents import Document

from langchain.retrievers.document_compressors import CrossEncoderReranker


def test_rerank() -> None:
    texts = [
        "aaa1",
        "bbb1",
        "aaa2",
        "bbb2",
        "aaa3",
        "bbb3",
    ]
    docs = list(map(lambda text: Document(page_content=text), texts))
    compressor = CrossEncoderReranker(model=FakeCrossEncoder())
    actual_docs = compressor.compress_documents(docs, "bbb2")
    actual = list(map(lambda doc: doc.page_content, actual_docs))
    expected_returned = ["bbb2", "bbb1", "bbb3"]
    expected_not_returned = ["aaa1", "aaa2", "aaa3"]
    assert all([text in actual for text in expected_returned])
    assert all([text not in actual for text in expected_not_returned])
    assert actual[0] == "bbb2"


def test_rerank_empty() -> None:
    docs: List[Document] = []
    compressor = CrossEncoderReranker(model=FakeCrossEncoder())
    actual_docs = compressor.compress_documents(docs, "query")
    assert len(actual_docs) == 0