Add a ListRerank document compressor (#13311)

- **Description:** This PR adds a new document compressor called `ListRerank`, derived from `BaseDocumentCompressor`. It is a near-exact implementation of the approach introduced in the paper [Zero-Shot Listwise Document Reranking with a Large Language Model](https://arxiv.org/pdf/2305.02156.pdf), which finds listwise reranking to outperform pointwise reranking. Pointwise reranking is roughly what LangChain already implements as [LLMChainFilter](https://github.com/langchain-ai/langchain/blob/master/libs/langchain/langchain/retrievers/document_compressors/chain_filter.py).
- **Issue:** None
- **Dependencies:** None
- **Tag maintainer:** @hwchase17 @izzymsft
- **Twitter handle:** @HarrisEMitchell

Notes:
1. I didn't add anything to `docs`. I wasn't sure which pattern to follow, since the [Cohere reranker sits under Retrievers](https://python.langchain.com/docs/integrations/retrievers/cohere-reranker) with other external document retrieval integrations, while other contextual compression docs live [here](https://python.langchain.com/docs/modules/data_connection/retrievers/contextual_compression/). Happy to contribute to either with some direction.
2. I followed the syntax, docstrings, and implementation patterns of nearby modules as closely as I could. One thing I didn't do was put the default prompt in a separate `.py` file like [Chain Filter](https://github.com/langchain-ai/langchain/blob/master/libs/langchain/langchain/retrievers/document_compressors/chain_filter_prompt.py) and [Chain Extract](https://github.com/langchain-ai/langchain/blob/master/libs/langchain/langchain/retrievers/document_compressors/chain_extract_prompt.py). Happy to follow that pattern if it would be preferred.

---------

Co-authored-by: Harrison Chase <hw.chase.17@gmail.com>
Co-authored-by: Bagatur <baskaryan@gmail.com>
Co-authored-by: Chester Curme <chester.curme@gmail.com>
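For context, here is a minimal sketch (not part of this diff) of how the new compressor plugs into a `ContextualCompressionRetriever`, the usual consumer of a `BaseDocumentCompressor`. `TinyRetriever` is a made-up stand-in for a real retriever, and a valid `OPENAI_API_KEY` plus `langchain-openai` are assumed:

```python
from typing import List

from langchain.retrievers import ContextualCompressionRetriever
from langchain.retrievers.document_compressors.listwise_rerank import LLMListwiseRerank
from langchain_core.callbacks import CallbackManagerForRetrieverRun
from langchain_core.documents import Document
from langchain_core.retrievers import BaseRetriever
from langchain_openai import ChatOpenAI


class TinyRetriever(BaseRetriever):
    """Toy retriever that returns a fixed candidate list (illustration only)."""

    docs: List[Document]

    def _get_relevant_documents(
        self, query: str, *, run_manager: CallbackManagerForRetrieverRun
    ) -> List[Document]:
        return self.docs


docs = [
    Document(page_content="Steve is my friend from home"),
    Document(page_content="I didn't always like yogurt"),
]
retriever = ContextualCompressionRetriever(
    base_compressor=LLMListwiseRerank.from_llm(
        llm=ChatOpenAI(model="gpt-3.5-turbo"), top_n=1
    ),
    base_retriever=TinyRetriever(docs=docs),
)
print(retriever.invoke("Who is Steve?"))
```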
@@ -15,6 +15,9 @@ from langchain.retrievers.document_compressors.cross_encoder_rerank import (
from langchain.retrievers.document_compressors.embeddings_filter import (
    EmbeddingsFilter,
)
from langchain.retrievers.document_compressors.listwise_rerank import (
    LLMListwiseRerank,
)

_module_lookup = {
    "FlashrankRerank": "langchain_community.document_compressors.flashrank_rerank",
@@ -31,6 +34,7 @@ def __getattr__(name: str) -> Any:
__all__ = [
    "DocumentCompressorPipeline",
    "EmbeddingsFilter",
    "LLMListwiseRerank",
    "LLMChainExtractor",
    "LLMChainFilter",
    "CohereRerank",
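The hunk above elides the body of `__getattr__`, but `_module_lookup` drives the standard PEP 562 lazy-import pattern. A self-contained sketch of that pattern, with the dict abbreviated (the real file maps every exported name):

```python
import importlib
from typing import Any

_module_lookup = {
    "FlashrankRerank": "langchain_community.document_compressors.flashrank_rerank",
}


def __getattr__(name: str) -> Any:
    # Import the backing module only when the attribute is first accessed,
    # so importing the package itself stays cheap.
    if name in _module_lookup:
        module = importlib.import_module(_module_lookup[name])
        return getattr(module, name)
    raise AttributeError(f"module {__name__} has no attribute {name}")
```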
@@ -0,0 +1,137 @@
"""Filter that uses an LLM to rerank documents listwise and select top-k."""

from typing import Any, Dict, List, Optional, Sequence

from langchain_core.callbacks import Callbacks
from langchain_core.documents import BaseDocumentCompressor, Document
from langchain_core.language_models import BaseLanguageModel
from langchain_core.prompts import BasePromptTemplate, ChatPromptTemplate
from langchain_core.pydantic_v1 import BaseModel, Field
from langchain_core.runnables import Runnable, RunnableLambda, RunnablePassthrough

_default_system_tmpl = """{context}

Sort the Documents by their relevance to the Query."""
_DEFAULT_PROMPT = ChatPromptTemplate.from_messages(
    [("system", _default_system_tmpl), ("human", "{query}")],
)


def _get_prompt_input(input_: dict) -> Dict[str, Any]:
    """Return the compression chain input."""
    documents = input_["documents"]
    context = ""
    for index, doc in enumerate(documents):
        context += f"Document ID: {index}\n```{doc.page_content}```\n\n"
    context += f"Documents = [Document ID: 0, ..., Document ID: {len(documents) - 1}]"
    return {"query": input_["query"], "context": context}


def _parse_ranking(results: dict) -> List[Document]:
    ranking = results["ranking"]
    docs = results["documents"]
    return [docs[i] for i in ranking.ranked_document_ids]


class LLMListwiseRerank(BaseDocumentCompressor):
    """Document compressor that uses `Zero-Shot Listwise Document Reranking`.

    Adapted from: https://arxiv.org/pdf/2305.02156.pdf

    ``LLMListwiseRerank`` uses a language model to rerank a list of documents based on
    their relevance to a query.

    **NOTE**: requires that underlying model implement ``with_structured_output``.

    Example usage:
        .. code-block:: python

            from langchain.retrievers.document_compressors.listwise_rerank import (
                LLMListwiseRerank,
            )
            from langchain_core.documents import Document
            from langchain_openai import ChatOpenAI

            documents = [
                Document("Sally is my friend from school"),
                Document("Steve is my friend from home"),
                Document("I didn't always like yogurt"),
                Document("I wonder why it's called football"),
                Document("Where's waldo"),
            ]

            reranker = LLMListwiseRerank.from_llm(
                llm=ChatOpenAI(model="gpt-3.5-turbo"), top_n=3
            )
            compressed_docs = reranker.compress_documents(documents, "Who is steve")

            assert len(compressed_docs) == 3
            assert "Steve" in compressed_docs[0].page_content
    """

    reranker: Runnable[Dict, List[Document]]
    """LLM-based reranker to use for filtering documents. Expected to take in a dict
    with 'documents: Sequence[Document]' and 'query: str' keys and output a
    List[Document]."""

    top_n: int = 3
    """Number of documents to return."""

    class Config:
        arbitrary_types_allowed = True

    def compress_documents(
        self,
        documents: Sequence[Document],
        query: str,
        callbacks: Optional[Callbacks] = None,
    ) -> Sequence[Document]:
        """Filter down documents based on their relevance to the query."""
        results = self.reranker.invoke(
            {"documents": documents, "query": query}, config={"callbacks": callbacks}
        )
        return results[: self.top_n]

    @classmethod
    def from_llm(
        cls,
        llm: BaseLanguageModel,
        *,
        prompt: Optional[BasePromptTemplate] = None,
        **kwargs: Any,
    ) -> "LLMListwiseRerank":
        """Create a LLMListwiseRerank document compressor from a language model.

        Args:
            llm: The language model to use for filtering. **Must implement
                BaseLanguageModel.with_structured_output().**
            prompt: The prompt to use for the filter.
            **kwargs: Additional arguments to pass to the constructor.

        Returns:
            A LLMListwiseRerank document compressor that uses the given language model.
        """

        if llm.with_structured_output == BaseLanguageModel.with_structured_output:
            raise ValueError(
                f"llm of type {type(llm)} does not implement `with_structured_output`."
            )

        class RankDocuments(BaseModel):
            """Rank the documents by their relevance to the user question.
            Rank from most to least relevant."""

            ranked_document_ids: List[int] = Field(
                ...,
                description=(
                    "The integer IDs of the documents, sorted from most to least "
                    "relevant to the user question."
                ),
            )

        _prompt = prompt if prompt is not None else _DEFAULT_PROMPT
        reranker = RunnablePassthrough.assign(
            ranking=RunnableLambda(_get_prompt_input)
            | _prompt
            | llm.with_structured_output(RankDocuments)
        ) | RunnableLambda(_parse_ranking)
        return cls(reranker=reranker, **kwargs)
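To see the shape of the chain `from_llm` builds without calling a real model, here is a sketch that swaps the `prompt | llm.with_structured_output(RankDocuments)` step for a deterministic stand-in. `_Ranking`, `_fake_rank`, and the fixed ordering are made up for illustration; only `langchain_core` is needed:

```python
from typing import List

from langchain_core.documents import Document
from langchain_core.pydantic_v1 import BaseModel
from langchain_core.runnables import RunnableLambda, RunnablePassthrough


class _Ranking(BaseModel):
    ranked_document_ids: List[int]


def _fake_rank(_: dict) -> _Ranking:
    # Stand-in for the LLM step: always claims document 1 beats document 0.
    return _Ranking(ranked_document_ids=[1, 0])


# Same topology as from_llm: assign() adds a "ranking" key alongside the
# original input, then the final lambda reorders the documents by it.
chain = RunnablePassthrough.assign(ranking=RunnableLambda(_fake_rank)) | RunnableLambda(
    lambda r: [r["documents"][i] for i in r["ranking"].ranked_document_ids]
)

docs = [Document(page_content="a"), Document(page_content="b")]
assert [d.page_content for d in chain.invoke({"documents": docs, "query": "q"})] == [
    "b",
    "a",
]
```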
@@ -0,0 +1,22 @@
from langchain_core.documents import Document

from langchain.retrievers.document_compressors.listwise_rerank import LLMListwiseRerank


def test_list_rerank() -> None:
    from langchain_openai import ChatOpenAI

    documents = [
        Document("Sally is my friend from school"),
        Document("Steve is my friend from home"),
        Document("I didn't always like yogurt"),
        Document("I wonder why it's called football"),
        Document("Where's waldo"),
    ]

    reranker = LLMListwiseRerank.from_llm(
        llm=ChatOpenAI(model="gpt-3.5-turbo"), top_n=3
    )
    compressed_docs = reranker.compress_documents(documents, "Who is steve")
    assert len(compressed_docs) == 3
    assert "Steve" in compressed_docs[0].page_content
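Because `reranker` is just a `Runnable[Dict, List[Document]]`, the compressor can also be exercised offline by injecting a deterministic runnable instead of going through `from_llm`. A sketch, not part of this PR's tests (`_reverse_rank` is illustrative):

```python
from typing import Dict, List

from langchain_core.documents import Document
from langchain_core.runnables import RunnableLambda

from langchain.retrievers.document_compressors.listwise_rerank import LLMListwiseRerank


def _reverse_rank(inputs: Dict) -> List[Document]:
    # Pretend the "model" ranked the candidates in reverse input order.
    return list(reversed(inputs["documents"]))


compressor = LLMListwiseRerank(reranker=RunnableLambda(_reverse_rank), top_n=2)
docs = [Document(page_content=t) for t in ("a", "b", "c")]
assert [d.page_content for d in compressor.compress_documents(docs, "q")] == ["c", "b"]
```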
@@ -0,0 +1,13 @@
import pytest

from langchain.retrievers.document_compressors.listwise_rerank import LLMListwiseRerank


@pytest.mark.requires("langchain_openai")
def test__list_rerank_init() -> None:
    from langchain_openai import ChatOpenAI

    LLMListwiseRerank.from_llm(
        llm=ChatOpenAI(api_key="foo"),  # type: ignore[arg-type]
        top_n=10,
    )
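`@pytest.mark.requires(...)` is a repo-local marker rather than a built-in pytest feature. A hedged sketch of how such a marker is commonly wired up in a `conftest.py` (that file is not part of this diff, so the details here are assumptions):

```python
import importlib.util
from typing import List

import pytest


def pytest_collection_modifyitems(
    config: pytest.Config, items: List[pytest.Item]
) -> None:
    # Skip any test marked @pytest.mark.requires("pkg") when "pkg" is not
    # importable in the current environment.
    for item in items:
        for marker in item.iter_markers(name="requires"):
            for pkg in marker.args:
                if importlib.util.find_spec(pkg) is None:
                    item.add_marker(pytest.mark.skip(reason=f"requires {pkg}"))
```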