diff --git a/docs/docs/integrations/document_transformers/cross_encoder_reranker.ipynb b/docs/docs/integrations/document_transformers/cross_encoder_reranker.ipynb new file mode 100644 index 00000000000..25c87c0ee8a --- /dev/null +++ b/docs/docs/integrations/document_transformers/cross_encoder_reranker.ipynb @@ -0,0 +1,273 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "fc0db1bc", + "metadata": {}, + "source": [ + "# Cross Encoder Reranker\n", + "\n", + "This notebook shows how to implement a reranker in a retriever with your own cross encoder from [HuggingFace cross encoder models](https://huggingface.co/cross-encoder) or HuggingFace models that implement a cross encoder function ([example: BAAI/bge-reranker-base](https://huggingface.co/BAAI/bge-reranker-base)). `SagemakerEndpointCrossEncoder` enables you to use these HuggingFace models loaded on Sagemaker.\n", + "\n", + "This builds on top of ideas in the [ContextualCompressionRetriever](/docs/modules/data_connection/retrievers/contextual_compression/). The overall structure of this document comes from the [Cohere Reranker documentation](/docs/integrations/retrievers/cohere-reranker.ipynb).\n", + "\n", + "For more on why a cross encoder can be used as a reranking mechanism in conjunction with embeddings for better retrieval, refer to the [HuggingFace Cross-Encoders documentation](https://www.sbert.net/examples/applications/cross-encoder/README.html)." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "b37bd138-4f3c-4d2c-bc4b-be705ce27a09", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "#!pip install faiss sentence_transformers\n", + "\n", + "# OR (depending on Python version)\n", + "\n", + "#!pip install faiss-cpu sentence_transformers" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "id": "28e8dc12", + "metadata": {}, + "outputs": [], + "source": [ + "# Helper function for printing docs\n", + "\n", + "\n", + "def pretty_print_docs(docs):\n", + " print(\n", + " f\"\\n{'-' * 100}\\n\".join(\n", + " [f\"Document {i+1}:\\n\\n\" + d.page_content for i, d in enumerate(docs)]\n", + " )\n", + " )" + ] + }, + { + "cell_type": "markdown", + "id": "6fa3d916", + "metadata": { + "tags": [] + }, + "source": [ + "## Set up the base vector store retriever\n", + "Let's start by initializing a simple vector store retriever and storing the 2022 State of the Union speech (in chunks). We can set up the retriever to retrieve a high number (20) of docs."
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "9fbcc58f", + "metadata": {}, + "outputs": [], + "source": [ + "from langchain_community.document_loaders import TextLoader\n", + "from langchain_community.embeddings import HuggingFaceEmbeddings\n", + "from langchain_community.vectorstores import FAISS\n", + "from langchain_text_splitters import RecursiveCharacterTextSplitter\n", + "\n", + "documents = TextLoader(\"../../modules/state_of_the_union.txt\").load()\n", + "text_splitter = RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=100)\n", + "texts = text_splitter.split_documents(documents)\n", + "embeddings_model = HuggingFaceEmbeddings(\n", + " model_name=\"sentence-transformers/msmarco-distilbert-dot-v5\"\n", + ")\n", + "retriever = FAISS.from_documents(texts, embeddings_model).as_retriever(\n", + " search_kwargs={\"k\": 20}\n", + ")\n", + "\n", + "query = \"What is the plan for the economy?\"\n", + "docs = retriever.get_relevant_documents(query)\n", + "pretty_print_docs(docs)" + ] + }, + { + "cell_type": "markdown", + "id": "b7648612", + "metadata": {}, + "source": [ + "## Doing reranking with CrossEncoderReranker\n", + "Now let's wrap our base retriever with a `ContextualCompressionRetriever`. `CrossEncoderReranker` uses `HuggingFaceCrossEncoder` to rerank the returned results." + ] + }, + { + "cell_type": "code", + "execution_count": 31, + "id": "9a658023", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Document 1:\n", + "\n", + "More infrastructure and innovation in America. \n", + "\n", + "More goods moving faster and cheaper in America. \n", + "\n", + "More jobs where you can earn a good living in America. \n", + "\n", + "And instead of relying on foreign supply chains, let’s make it in America. \n", + "\n", + "Economists call it “increasing the productive capacity of our economy.” \n", + "\n", + "I call it building a better America. \n", + "\n", + "My plan to fight inflation will lower your costs and lower the deficit.\n", + "----------------------------------------------------------------------------------------------------\n", + "Document 2:\n", + "\n", + "Second – cut energy costs for families an average of $500 a year by combatting climate change. \n", + "\n", + "Let’s provide investments and tax credits to weatherize your homes and businesses to be energy efficient and you get a tax credit; double America’s clean energy production in solar, wind, and so much more; lower the price of electric vehicles, saving you another $80 a month because you’ll never have to pay at the gas pump again.\n", + "----------------------------------------------------------------------------------------------------\n", + "Document 3:\n", + "\n", + "Look at cars. \n", + "\n", + "Last year, there weren’t enough semiconductors to make all the cars that people wanted to buy. \n", + "\n", + "And guess what, prices of automobiles went up. \n", + "\n", + "So—we have a choice. \n", + "\n", + "One way to fight inflation is to drive down wages and make Americans poorer. \n", + "\n", + "I have a better plan to fight inflation. \n", + "\n", + "Lower your costs, not your wages. \n", + "\n", + "Make more cars and semiconductors in America. \n", + "\n", + "More infrastructure and innovation in America. 
\n", + "\n", + "More goods moving faster and cheaper in America.\n" + ] + } + ], + "source": [ + "from langchain.retrievers import ContextualCompressionRetriever\n", + "from langchain.retrievers.document_compressors import CrossEncoderReranker\n", + "from langchain_community.cross_encoders import HuggingFaceCrossEncoder\n", + "\n", + "model = HuggingFaceCrossEncoder(model_name=\"BAAI/bge-reranker-base\")\n", + "compressor = CrossEncoderReranker(model=model, top_n=3)\n", + "compression_retriever = ContextualCompressionRetriever(\n", + " base_compressor=compressor, base_retriever=retriever\n", + ")\n", + "\n", + "compressed_docs = compression_retriever.get_relevant_documents(\n", + " \"What is the plan for the economy?\"\n", + ")\n", + "pretty_print_docs(compressed_docs)" + ] + }, + { + "cell_type": "markdown", + "id": "419a2bf3-de4b-4c4d-9a40-4336552f604c", + "metadata": {}, + "source": [ + "## Uploading HuggingFace model to SageMaker endpoint\n", + "\n", + "Refer to [this article](https://www.philschmid.de/custom-inference-huggingface-sagemaker) for general guideline. Here is a simple `inference.py` for creating an endpoint that works with `SagemakerEndpointCrossEncoder`.\n", + "\n", + "It downloads HuggingFace model on the fly, so you do not need to keep the model artifacts such as `pytorch_model.bin` in your `model.tar.gz`." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "e579c743-40c3-432f-9483-0982e2808f9a", + "metadata": {}, + "outputs": [], + "source": [ + "import json\n", + "import logging\n", + "from typing import List\n", + "\n", + "import torch\n", + "from sagemaker_inference import encoder\n", + "from transformers import AutoModelForSequenceClassification, AutoTokenizer\n", + "\n", + "PAIRS = \"pairs\"\n", + "SCORES = \"scores\"\n", + "\n", + "\n", + "class CrossEncoder:\n", + " def __init__(self) -> None:\n", + " self.device = (\n", + " torch.device(\"cuda\") if torch.cuda.is_available() else torch.device(\"cpu\")\n", + " )\n", + " logging.info(f\"Using device: {self.device}\")\n", + " model_name = \"BAAI/bge-reranker-base\"\n", + " self.tokenizer = AutoTokenizer.from_pretrained(model_name)\n", + " self.model = AutoModelForSequenceClassification.from_pretrained(model_name)\n", + " self.model = self.model.to(self.device)\n", + "\n", + " def __call__(self, pairs: List[List[str]]) -> List[float]:\n", + " with torch.inference_mode():\n", + " inputs = self.tokenizer(\n", + " pairs,\n", + " padding=True,\n", + " truncation=True,\n", + " return_tensors=\"pt\",\n", + " max_length=512,\n", + " )\n", + " inputs = inputs.to(self.device)\n", + " scores = (\n", + " self.model(**inputs, return_dict=True)\n", + " .logits.view(\n", + " -1,\n", + " )\n", + " .float()\n", + " )\n", + "\n", + " return scores.detach().cpu().tolist()\n", + "\n", + "\n", + "def model_fn(model_dir: str) -> CrossEncoder:\n", + " try:\n", + " return CrossEncoder()\n", + " except Exception:\n", + " logging.exception(f\"Failed to load model from: {model_dir}\")\n", + " raise\n", + "\n", + "\n", + "def transform_fn(\n", + " cross_encoder: CrossEncoder, input_data: bytes, content_type: str, accept: str\n", + ") -> bytes:\n", + " payload = json.loads(input_data)\n", + " model_output = cross_encoder(**payload)\n", + " output = {SCORES: model_output}\n", + " return encoder.encode(output, accept)" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + 
"version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.9.1" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/libs/community/langchain_community/cross_encoders/__init__.py b/libs/community/langchain_community/cross_encoders/__init__.py new file mode 100644 index 00000000000..be68809d193 --- /dev/null +++ b/libs/community/langchain_community/cross_encoders/__init__.py @@ -0,0 +1,30 @@ +"""**Cross encoders** are wrappers around cross encoder models from different APIs and + services. + +**Cross encoder models** can be LLMs or not. + +**Class hierarchy:** + +.. code-block:: + + BaseCrossEncoder --> CrossEncoder # Examples: SagemakerEndpointCrossEncoder +""" + + +import logging + +from langchain_community.cross_encoders.base import BaseCrossEncoder +from langchain_community.cross_encoders.fake import FakeCrossEncoder +from langchain_community.cross_encoders.huggingface import HuggingFaceCrossEncoder +from langchain_community.cross_encoders.sagemaker_endpoint import ( + SagemakerEndpointCrossEncoder, +) + +logger = logging.getLogger(__name__) + +__all__ = [ + "BaseCrossEncoder", + "FakeCrossEncoder", + "HuggingFaceCrossEncoder", + "SagemakerEndpointCrossEncoder", +] diff --git a/libs/community/langchain_community/cross_encoders/base.py b/libs/community/langchain_community/cross_encoders/base.py new file mode 100644 index 00000000000..98fa0568980 --- /dev/null +++ b/libs/community/langchain_community/cross_encoders/base.py @@ -0,0 +1,17 @@ +from abc import ABC, abstractmethod +from typing import List, Tuple + + +class BaseCrossEncoder(ABC): + """Interface for cross encoder models.""" + + @abstractmethod + def score(self, text_pairs: List[Tuple[str, str]]) -> List[float]: + """Score pairs' similarity. + + Args: + text_pairs: List of pairs of texts. + + Returns: + List of scores. + """ diff --git a/libs/community/langchain_community/cross_encoders/fake.py b/libs/community/langchain_community/cross_encoders/fake.py new file mode 100644 index 00000000000..91e1f702d8c --- /dev/null +++ b/libs/community/langchain_community/cross_encoders/fake.py @@ -0,0 +1,18 @@ +from difflib import SequenceMatcher +from typing import List, Tuple + +from langchain_core.pydantic_v1 import BaseModel + +from langchain_community.cross_encoders.base import BaseCrossEncoder + + +class FakeCrossEncoder(BaseCrossEncoder, BaseModel): + """Fake cross encoder model.""" + + def score(self, text_pairs: List[Tuple[str, str]]) -> List[float]: + scores = list( + map( + lambda pair: SequenceMatcher(None, pair[0], pair[1]).ratio(), text_pairs + ) + ) + return scores diff --git a/libs/community/langchain_community/cross_encoders/huggingface.py b/libs/community/langchain_community/cross_encoders/huggingface.py new file mode 100644 index 00000000000..6cfbceff7a6 --- /dev/null +++ b/libs/community/langchain_community/cross_encoders/huggingface.py @@ -0,0 +1,63 @@ +from typing import Any, Dict, List, Tuple + +from langchain_core.pydantic_v1 import BaseModel, Extra, Field + +from langchain_community.cross_encoders.base import BaseCrossEncoder + +DEFAULT_MODEL_NAME = "BAAI/bge-reranker-base" + + +class HuggingFaceCrossEncoder(BaseModel, BaseCrossEncoder): + """HuggingFace cross encoder models. + + Example: + .. 
code-block:: python + + from langchain_community.cross_encoders import HuggingFaceCrossEncoder + + model_name = "BAAI/bge-reranker-base" + model_kwargs = {'device': 'cpu'} + hf = HuggingFaceCrossEncoder( + model_name=model_name, + model_kwargs=model_kwargs + ) + """ + + client: Any #: :meta private: + model_name: str = DEFAULT_MODEL_NAME + """Model name to use.""" + model_kwargs: Dict[str, Any] = Field(default_factory=dict) + """Keyword arguments to pass to the model.""" + + def __init__(self, **kwargs: Any): + """Initialize the sentence_transformers cross encoder.""" + super().__init__(**kwargs) + try: + import sentence_transformers + + except ImportError as exc: + raise ImportError( + "Could not import sentence_transformers python package. " + "Please install it with `pip install sentence-transformers`." + ) from exc + + self.client = sentence_transformers.CrossEncoder( + self.model_name, **self.model_kwargs + ) + + class Config: + """Configuration for this pydantic object.""" + + extra = Extra.forbid + + def score(self, text_pairs: List[Tuple[str, str]]) -> List[float]: + """Compute similarity scores using a HuggingFace transformer model. + + Args: + text_pairs: The list of text pairs to score for similarity. + + Returns: + List of scores, one for each pair. + """ + scores = self.client.predict(text_pairs) + return scores diff --git a/libs/community/langchain_community/cross_encoders/sagemaker_endpoint.py b/libs/community/langchain_community/cross_encoders/sagemaker_endpoint.py new file mode 100644 index 00000000000..20d6b85c6a0 --- /dev/null +++ b/libs/community/langchain_community/cross_encoders/sagemaker_endpoint.py @@ -0,0 +1,151 @@ +import json +from typing import Any, Dict, List, Optional, Tuple + +from langchain_core.pydantic_v1 import BaseModel, Extra, root_validator + +from langchain_community.cross_encoders.base import BaseCrossEncoder + + +class CrossEncoderContentHandler: + """Content handler for CrossEncoder class.""" + + content_type = "application/json" + accepts = "application/json" + + def transform_input(self, text_pairs: List[Tuple[str, str]]) -> bytes: + input_str = json.dumps({"text_pairs": text_pairs}) + return input_str.encode("utf-8") + + def transform_output(self, output: Any) -> List[float]: + response_json = json.loads(output.read().decode("utf-8")) + scores = response_json["scores"] + return scores + + +class SagemakerEndpointCrossEncoder(BaseModel, BaseCrossEncoder): + """SageMaker Inference CrossEncoder endpoint. + + To use, you must supply the endpoint name from your deployed + Sagemaker model and the region where it is deployed. + + To authenticate, the AWS client uses the following methods to + automatically load credentials: + https://boto3.amazonaws.com/v1/documentation/api/latest/guide/credentials.html + + If a specific credential profile should be used, you must pass + the name of the profile from the ~/.aws/credentials file. + + Make sure the credentials / roles used have the required policies to + access the Sagemaker endpoint. + See: https://docs.aws.amazon.com/IAM/latest/UserGuide/access_policies.html + """ + + """ + Example: + .. 
code-block:: python + + + from langchain_community.cross_encoders import SagemakerEndpointCrossEncoder + endpoint_name = ( + "my-endpoint-name" + ) + region_name = ( + "us-west-2" + ) + credentials_profile_name = ( + "default" + ) + se = SagemakerEndpointCrossEncoder( + endpoint_name=endpoint_name, + region_name=region_name, + credentials_profile_name=credentials_profile_name + ) + """ + client: Any #: :meta private: + + endpoint_name: str = "" + """The name of the endpoint from the deployed Sagemaker model. + Must be unique within an AWS Region.""" + + region_name: str = "" + """The AWS region where the Sagemaker model is deployed, e.g. `us-west-2`.""" + + credentials_profile_name: Optional[str] = None + """The name of the profile in the ~/.aws/credentials or ~/.aws/config files, which + has either access keys or role information specified. + If not specified, the default credential profile or, if on an EC2 instance, + credentials from IMDS will be used. + See: https://boto3.amazonaws.com/v1/documentation/api/latest/guide/credentials.html + """ + + content_handler: CrossEncoderContentHandler = CrossEncoderContentHandler() + + model_kwargs: Optional[Dict] = None + """Keyword arguments to pass to the model.""" + + endpoint_kwargs: Optional[Dict] = None + """Optional attributes passed to the invoke_endpoint + function. See the `boto3`_ docs for more info. + .. _boto3: https://boto3.amazonaws.com/v1/documentation/api/latest/index.html + """ + + class Config: + """Configuration for this pydantic object.""" + + extra = Extra.forbid + arbitrary_types_allowed = True + + @root_validator() + def validate_environment(cls, values: Dict) -> Dict: + """Validate that AWS credentials and python package exist in environment.""" + try: + import boto3 + + try: + if values["credentials_profile_name"] is not None: + session = boto3.Session( + profile_name=values["credentials_profile_name"] + ) + else: + # use default credentials + session = boto3.Session() + + values["client"] = session.client( + "sagemaker-runtime", region_name=values["region_name"] + ) + + except Exception as e: + raise ValueError( + "Could not load credentials to authenticate with AWS client. " + "Please check that credentials in the specified " + "profile name are valid." + ) from e + + except ImportError: + raise ImportError( + "Could not import boto3 python package. " + "Please install it with `pip install boto3`."
+ ) + return values + + def score(self, text_pairs: List[Tuple[str, str]]) -> List[float]: + """Call out to SageMaker Inference CrossEncoder endpoint.""" + _endpoint_kwargs = self.endpoint_kwargs or {} + + body = self.content_handler.transform_input(text_pairs) + content_type = self.content_handler.content_type + accepts = self.content_handler.accepts + + # send request + try: + response = self.client.invoke_endpoint( + EndpointName=self.endpoint_name, + Body=body, + ContentType=content_type, + Accept=accepts, + **_endpoint_kwargs, + ) + except Exception as e: + raise ValueError(f"Error raised by inference endpoint: {e}") + + return self.content_handler.transform_output(response["Body"]) diff --git a/libs/community/tests/integration_tests/cross_encoders/__init__.py b/libs/community/tests/integration_tests/cross_encoders/__init__.py new file mode 100644 index 00000000000..26b0d562fe5 --- /dev/null +++ b/libs/community/tests/integration_tests/cross_encoders/__init__.py @@ -0,0 +1 @@ +"""Test cross encoder integrations.""" diff --git a/libs/community/tests/integration_tests/cross_encoders/test_huggingface.py b/libs/community/tests/integration_tests/cross_encoders/test_huggingface.py new file mode 100644 index 00000000000..12c5d72b9e1 --- /dev/null +++ b/libs/community/tests/integration_tests/cross_encoders/test_huggingface.py @@ -0,0 +1,22 @@ +"""Test huggingface cross encoders.""" + +from langchain_community.cross_encoders import HuggingFaceCrossEncoder + + +def _assert(encoder: HuggingFaceCrossEncoder) -> None: + query = "I love you" + texts = ["I love you", "I like you", "I don't like you", "I hate you"] + output = encoder.score([(query, text) for text in texts]) + + for i in range(len(texts) - 1): + assert output[i] > output[i + 1] + + +def test_huggingface_cross_encoder() -> None: + encoder = HuggingFaceCrossEncoder() + _assert(encoder) + + +def test_huggingface_cross_encoder_with_designated_model_name() -> None: + encoder = HuggingFaceCrossEncoder(model_name="cross-encoder/ms-marco-MiniLM-L-6-v2") + _assert(encoder) diff --git a/libs/langchain/langchain/retrievers/document_compressors/__init__.py b/libs/langchain/langchain/retrievers/document_compressors/__init__.py index a4c1456cf24..6a32c453a3c 100644 --- a/libs/langchain/langchain/retrievers/document_compressors/__init__.py +++ b/libs/langchain/langchain/retrievers/document_compressors/__init__.py @@ -6,6 +6,9 @@ from langchain.retrievers.document_compressors.chain_filter import ( LLMChainFilter, ) from langchain.retrievers.document_compressors.cohere_rerank import CohereRerank +from langchain.retrievers.document_compressors.cross_encoder_rerank import ( + CrossEncoderReranker, +) from langchain.retrievers.document_compressors.embeddings_filter import ( EmbeddingsFilter, ) @@ -17,5 +20,6 @@ __all__ = [ "LLMChainExtractor", "LLMChainFilter", "CohereRerank", + "CrossEncoderReranker", "FlashrankRerank", ] diff --git a/libs/langchain/langchain/retrievers/document_compressors/cross_encoder_rerank.py b/libs/langchain/langchain/retrievers/document_compressors/cross_encoder_rerank.py new file mode 100644 index 00000000000..e4047fc0723 --- /dev/null +++ b/libs/langchain/langchain/retrievers/document_compressors/cross_encoder_rerank.py @@ -0,0 +1,47 @@ +from __future__ import annotations + +import operator +from typing import Optional, Sequence + +from langchain_community.cross_encoders import BaseCrossEncoder +from langchain_core.callbacks import Callbacks +from langchain_core.documents import BaseDocumentCompressor, Document +from 
langchain_core.pydantic_v1 import Extra + + +class CrossEncoderReranker(BaseDocumentCompressor): + """Document compressor that uses CrossEncoder for reranking.""" + + model: BaseCrossEncoder + """CrossEncoder model to use for scoring similarity + between the query and documents.""" + top_n: int = 3 + """Number of documents to return.""" + + class Config: + """Configuration for this pydantic object.""" + + extra = Extra.forbid + arbitrary_types_allowed = True + + def compress_documents( + self, + documents: Sequence[Document], + query: str, + callbacks: Optional[Callbacks] = None, + ) -> Sequence[Document]: + """ + Rerank documents using CrossEncoder. + + Args: + documents: A sequence of documents to compress. + query: The query to use for compressing the documents. + callbacks: Callbacks to run during the compression process. + + Returns: + A sequence of compressed documents. + """ + scores = self.model.score([(query, doc.page_content) for doc in documents]) + docs_with_scores = list(zip(documents, scores)) + result = sorted(docs_with_scores, key=operator.itemgetter(1), reverse=True) + return [doc for doc, _ in result[: self.top_n]] diff --git a/libs/langchain/tests/unit_tests/retrievers/document_compressors/test_cross_encoder_reranker.py b/libs/langchain/tests/unit_tests/retrievers/document_compressors/test_cross_encoder_reranker.py new file mode 100644 index 00000000000..29404aada1a --- /dev/null +++ b/libs/langchain/tests/unit_tests/retrievers/document_compressors/test_cross_encoder_reranker.py @@ -0,0 +1,34 @@ +"""Unit test for CrossEncoderReranker.""" +from typing import List + +from langchain_community.cross_encoders import FakeCrossEncoder +from langchain_core.documents import Document + +from langchain.retrievers.document_compressors import CrossEncoderReranker + + +def test_rerank() -> None: + texts = [ + "aaa1", + "bbb1", + "aaa2", + "bbb2", + "aaa3", + "bbb3", + ] + docs = list(map(lambda text: Document(page_content=text), texts)) + compressor = CrossEncoderReranker(model=FakeCrossEncoder()) + actual_docs = compressor.compress_documents(docs, "bbb2") + actual = list(map(lambda doc: doc.page_content, actual_docs)) + expected_returned = ["bbb2", "bbb1", "bbb3"] + expected_not_returned = ["aaa1", "aaa2", "aaa3"] + assert all([text in actual for text in expected_returned]) + assert all([text not in actual for text in expected_not_returned]) + assert actual[0] == "bbb2" + + +def test_rerank_empty() -> None: + docs: List[Document] = [] + compressor = CrossEncoderReranker(model=FakeCrossEncoder()) + actual_docs = compressor.compress_documents(docs, "query") + assert len(actual_docs) == 0
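
As a quick way to exercise the new compressor end to end, here is a minimal sketch that wires `CrossEncoderReranker` to the `FakeCrossEncoder` added in this diff, so no model download or AWS endpoint is needed; it assumes `langchain` and `langchain-community` are installed with these changes applied:

    from langchain_community.cross_encoders import FakeCrossEncoder
    from langchain_core.documents import Document

    from langchain.retrievers.document_compressors import CrossEncoderReranker

    # FakeCrossEncoder scores each (query, text) pair with difflib's
    # SequenceMatcher ratio, so it stands in for a real cross encoder here.
    docs = [Document(page_content=text) for text in ["aaa1", "bbb1", "aaa2", "bbb2"]]

    # Keep only the top 2 documents by cross-encoder score.
    reranker = CrossEncoderReranker(model=FakeCrossEncoder(), top_n=2)
    reranked = reranker.compress_documents(docs, "bbb2")

    # "bbb2" matches the query exactly, so it should be ranked first.
    print([doc.page_content for doc in reranked])

Swapping `FakeCrossEncoder()` for `HuggingFaceCrossEncoder(...)` or `SagemakerEndpointCrossEncoder(...)` gives the real reranking behavior shown in the notebook above, since all three implement `BaseCrossEncoder.score`.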