Add a ListRerank document compressor (#13311)

- **Description:** This PR adds a new document compressor called
`ListRerank`, derived from `BaseDocumentCompressor`. It is a near-exact
implementation of the approach introduced in the paper [Zero-Shot Listwise
Document Reranking with a Large Language
Model](https://arxiv.org/pdf/2305.02156.pdf), which the authors find to
outperform pointwise reranking. Pointwise reranking is roughly what
LangChain already implements as
[LLMChainFilter](https://github.com/langchain-ai/langchain/blob/master/libs/langchain/langchain/retrievers/document_compressors/chain_filter.py);
a sketch contrasting the two approaches follows this list.
- **Issue:** None
- **Dependencies:** None
- **Tag maintainer:** @hwchase17 @izzymsft
- **Twitter handle:** @HarrisEMitchell
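
For context, here is a minimal sketch contrasting pointwise filtering with the listwise reranking added here. The query and documents are invented for illustration; treat this as a sketch of intended usage, not canonical documentation:

```python
from langchain.retrievers.document_compressors import LLMChainFilter, LLMListwiseRerank
from langchain_core.documents import Document
from langchain_openai import ChatOpenAI

llm = ChatOpenAI(model="gpt-3.5-turbo")
docs = [
    Document("Steve is my friend from home"),
    Document("I didn't always like yogurt"),
]

# Pointwise (LLMChainFilter): each document gets an independent yes/no
# relevance call, so documents are never compared against each other.
kept = LLMChainFilter.from_llm(llm).compress_documents(docs, "Who is Steve?")

# Listwise (this PR): a single call sees all documents at once and returns
# a ranking, from which the top_n are kept.
ranked = LLMListwiseRerank.from_llm(llm, top_n=1).compress_documents(
    docs, "Who is Steve?"
)
```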

Notes:
1. I didn't add anything to `docs`. I wasn't exactly sure which pattern
to follow: the [Cohere reranker sits under
Retrievers](https://python.langchain.com/docs/integrations/retrievers/cohere-reranker)
with the other external document retrieval integrations, while the rest of
contextual compression is documented
[here](https://python.langchain.com/docs/modules/data_connection/retrievers/contextual_compression/).
Happy to contribute to either with some direction; a sketch of the
contextual-compression wiring follows these notes.
2. I followed syntax, docstrings, implementation patterns, etc. as closely
as I could, based on nearby modules. One thing I didn't do was put the
default prompt in a separate `.py` file like [Chain
Filter](https://github.com/langchain-ai/langchain/blob/master/libs/langchain/langchain/retrievers/document_compressors/chain_filter_prompt.py)
and [Chain
Extract](https://github.com/langchain-ai/langchain/blob/master/libs/langchain/langchain/retrievers/document_compressors/chain_extract_prompt.py).
Happy to follow that pattern if it would be preferred.
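
For reference, a minimal sketch of how the new compressor would plug into the contextual-compression pattern mentioned in note 1. Here `base_retriever` stands in for any existing retriever (e.g. a vector store's `.as_retriever()`) and is an assumption, not part of this PR:

```python
from langchain.retrievers import ContextualCompressionRetriever
from langchain.retrievers.document_compressors import LLMListwiseRerank
from langchain_openai import ChatOpenAI

compressor = LLMListwiseRerank.from_llm(
    llm=ChatOpenAI(model="gpt-3.5-turbo"), top_n=3
)
# base_retriever is assumed to exist already, e.g. vectorstore.as_retriever().
compression_retriever = ContextualCompressionRetriever(
    base_compressor=compressor, base_retriever=base_retriever
)
reranked_docs = compression_retriever.invoke("Who is Steve?")
```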

---------

Co-authored-by: Harrison Chase <hw.chase.17@gmail.com>
Co-authored-by: Bagatur <baskaryan@gmail.com>
Co-authored-by: Chester Curme <chester.curme@gmail.com>
Commit 61ea7bf60b (parent 4c651ba13a)
Author: Evan Harris
Date: 2024-07-18 14:34:38 -06:00
Committed by: GitHub
6 changed files with 228 additions and 1 deletion


@@ -15,6 +15,9 @@ from langchain.retrievers.document_compressors.cross_encoder_rerank import (
from langchain.retrievers.document_compressors.embeddings_filter import (
    EmbeddingsFilter,
)
from langchain.retrievers.document_compressors.listwise_rerank import (
    LLMListwiseRerank,
)

_module_lookup = {
    "FlashrankRerank": "langchain_community.document_compressors.flashrank_rerank",
@@ -31,6 +34,7 @@ def __getattr__(name: str) -> Any:
__all__ = [
    "DocumentCompressorPipeline",
    "EmbeddingsFilter",
    "LLMListwiseRerank",
    "LLMChainExtractor",
    "LLMChainFilter",
    "CohereRerank",


@@ -0,0 +1,137 @@
"""Filter that uses an LLM to rerank documents listwise and select top-k."""
from typing import Any, Dict, List, Optional, Sequence
from langchain_core.callbacks import Callbacks
from langchain_core.documents import BaseDocumentCompressor, Document
from langchain_core.language_models import BaseLanguageModel
from langchain_core.prompts import BasePromptTemplate, ChatPromptTemplate
from langchain_core.pydantic_v1 import BaseModel, Field
from langchain_core.runnables import Runnable, RunnableLambda, RunnablePassthrough
_default_system_tmpl = """{context}
Sort the Documents by their relevance to the Query."""
_DEFAULT_PROMPT = ChatPromptTemplate.from_messages(
[("system", _default_system_tmpl), ("human", "{query}")],
)
def _get_prompt_input(input_: dict) -> Dict[str, Any]:
"""Return the compression chain input."""
documents = input_["documents"]
context = ""
for index, doc in enumerate(documents):
context += f"Document ID: {index}\n```{doc.page_content}```\n\n"
context += f"Documents = [Document ID: 0, ..., Document ID: {len(documents) - 1}]"
return {"query": input_["query"], "context": context}
def _parse_ranking(results: dict) -> List[Document]:
ranking = results["ranking"]
docs = results["documents"]
return [docs[i] for i in ranking.ranked_document_ids]
class LLMListwiseRerank(BaseDocumentCompressor):
"""Document compressor that uses `Zero-Shot Listwise Document Reranking`.
Adapted from: https://arxiv.org/pdf/2305.02156.pdf
``LLMListwiseRerank`` uses a language model to rerank a list of documents based on
their relevance to a query.
**NOTE**: requires that underlying model implement ``with_structured_output``.
Example usage:
.. code-block:: python
from langchain.retrievers.document_compressors.listwise_rerank import (
LLMListwiseRerank,
)
from langchain_core.documents import Document
from langchain_openai import ChatOpenAI
documents = [
Document("Sally is my friend from school"),
Document("Steve is my friend from home"),
Document("I didn't always like yogurt"),
Document("I wonder why it's called football"),
Document("Where's waldo"),
]
reranker = LLMListwiseRerank.from_llm(
llm=ChatOpenAI(model="gpt-3.5-turbo"), top_n=3
)
compressed_docs = reranker.compress_documents(documents, "Who is steve")
assert len(compressed_docs) == 3
assert "Steve" in compressed_docs[0].page_content
"""
reranker: Runnable[Dict, List[Document]]
"""LLM-based reranker to use for filtering documents. Expected to take in a dict
with 'documents: Sequence[Document]' and 'query: str' keys and output a
List[Document]."""
top_n: int = 3
"""Number of documents to return."""
class Config:
arbitrary_types_allowed = True
def compress_documents(
self,
documents: Sequence[Document],
query: str,
callbacks: Optional[Callbacks] = None,
) -> Sequence[Document]:
"""Filter down documents based on their relevance to the query."""
results = self.reranker.invoke(
{"documents": documents, "query": query}, config={"callbacks": callbacks}
)
return results[: self.top_n]
@classmethod
def from_llm(
cls,
llm: BaseLanguageModel,
*,
prompt: Optional[BasePromptTemplate] = None,
**kwargs: Any,
) -> "LLMListwiseRerank":
"""Create a LLMListwiseRerank document compressor from a language model.
Args:
llm: The language model to use for filtering. **Must implement
BaseLanguageModel.with_structured_output().**
prompt: The prompt to use for the filter.
**kwargs: Additional arguments to pass to the constructor.
Returns:
A LLMListwiseRerank document compressor that uses the given language model.
"""
if llm.with_structured_output == BaseLanguageModel.with_structured_output:
raise ValueError(
f"llm of type {type(llm)} does not implement `with_structured_output`."
)
class RankDocuments(BaseModel):
"""Rank the documents by their relevance to the user question.
Rank from most to least relevant."""
ranked_document_ids: List[int] = Field(
...,
description=(
"The integer IDs of the documents, sorted from most to least "
"relevant to the user question."
),
)
_prompt = prompt if prompt is not None else _DEFAULT_PROMPT
reranker = RunnablePassthrough.assign(
ranking=RunnableLambda(_get_prompt_input)
| _prompt
| llm.with_structured_output(RankDocuments)
) | RunnableLambda(_parse_ranking)
return cls(reranker=reranker, **kwargs)
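
To make the prompt construction above concrete, here is a small runnable sketch that replicates what `_get_prompt_input` feeds the model as `{context}` (documents invented for illustration):

```python
from langchain_core.documents import Document

documents = [
    Document("Steve is my friend from home"),
    Document("I didn't always like yogurt"),
]

# Mirrors the formatting loop in _get_prompt_input above.
context = ""
for index, doc in enumerate(documents):
    context += f"Document ID: {index}\n```{doc.page_content}```\n\n"
context += f"Documents = [Document ID: 0, ..., Document ID: {len(documents) - 1}]"

print(context)
# Document ID: 0
# ```Steve is my friend from home```
#
# Document ID: 1
# ```I didn't always like yogurt```
#
# Documents = [Document ID: 0, ..., Document ID: 1]
```

The structured-output step then asks the model for `ranked_document_ids`, and `_parse_ranking` maps those integer IDs back to the original `Document` objects.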


@@ -0,0 +1,22 @@
from langchain_core.documents import Document

from langchain.retrievers.document_compressors.listwise_rerank import LLMListwiseRerank


def test_list_rerank() -> None:
    from langchain_openai import ChatOpenAI

    documents = [
        Document("Sally is my friend from school"),
        Document("Steve is my friend from home"),
        Document("I didn't always like yogurt"),
        Document("I wonder why it's called football"),
        Document("Where's waldo"),
    ]
    reranker = LLMListwiseRerank.from_llm(
        llm=ChatOpenAI(model="gpt-3.5-turbo"), top_n=3
    )
    compressed_docs = reranker.compress_documents(documents, "Who is steve")
    assert len(compressed_docs) == 3
    assert "Steve" in compressed_docs[0].page_content


@@ -0,0 +1,13 @@
import pytest

from langchain.retrievers.document_compressors.listwise_rerank import LLMListwiseRerank


@pytest.mark.requires("langchain_openai")
def test__list_rerank_init() -> None:
    from langchain_openai import ChatOpenAI

    LLMListwiseRerank.from_llm(
        llm=ChatOpenAI(api_key="foo"),  # type: ignore[arg-type]
        top_n=10,
    )
)