diff --git a/docs/extras/modules/data_connection/document_transformers/post_retrieval/long_context_reorder.ipynb b/docs/extras/modules/data_connection/document_transformers/post_retrieval/long_context_reorder.ipynb new file mode 100644 index 00000000000..b77aa2f3acf --- /dev/null +++ b/docs/extras/modules/data_connection/document_transformers/post_retrieval/long_context_reorder.ipynb @@ -0,0 +1,176 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "fc0db1bc", + "metadata": {}, + "source": [ + "# Lost in the middle: The problem with long contexts\n", + "\n", + "No matter the architecture of your model, there is a sustancial performance degradation when you include 10+ retrieved documents.\n", + "In brief: When models must access relevant information in the middle of long contexts, then tend to ignore the provided documents.\n", + "See: https://arxiv.org/abs//2307.03172\n", + "\n", + "To avoid this issue you can re-order documents after retrieval to avoid performance degradation." + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "id": "49cbcd8e", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "[Document(page_content='This is a document about the Boston Celtics', metadata={}),\n", + " Document(page_content='The Celtics are my favourite team.', metadata={}),\n", + " Document(page_content='L. Kornet is one of the best Celtics players.', metadata={}),\n", + " Document(page_content='The Boston Celtics won the game by 20 points', metadata={}),\n", + " Document(page_content='Larry Bird was an iconic NBA player.', metadata={}),\n", + " Document(page_content='Elden Ring is one of the best games in the last 15 years.', metadata={}),\n", + " Document(page_content='Basquetball is a great sport.', metadata={}),\n", + " Document(page_content='I simply love going to the movies', metadata={}),\n", + " Document(page_content='Fly me to the moon is one of my favourite songs.', metadata={}),\n", + " Document(page_content='This is just a random text.', metadata={})]" + ] + }, + "execution_count": 2, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "import os\n", + "import chromadb\n", + "from langchain.vectorstores import Chroma\n", + "from langchain.embeddings import HuggingFaceEmbeddings\n", + "from langchain.document_transformers import (\n", + " LongContextReorder,\n", + ")\n", + "from langchain.chains import StuffDocumentsChain, LLMChain\n", + "from langchain.prompts import PromptTemplate\n", + "from langchain.llms import OpenAI\n", + "\n", + "# Get embeddings.\n", + "embeddings = HuggingFaceEmbeddings(model_name=\"all-MiniLM-L6-v2\")\n", + "\n", + "texts = [\n", + " \"Basquetball is a great sport.\",\n", + " \"Fly me to the moon is one of my favourite songs.\",\n", + " \"The Celtics are my favourite team.\",\n", + " \"This is a document about the Boston Celtics\",\n", + " \"I simply love going to the movies\",\n", + " \"The Boston Celtics won the game by 20 points\",\n", + " \"This is just a random text.\",\n", + " \"Elden Ring is one of the best games in the last 15 years.\",\n", + " \"L. Kornet is one of the best Celtics players.\",\n", + " \"Larry Bird was an iconic NBA player.\",\n", + "]\n", + "\n", + "# Create a retriever\n", + "retriever = Chroma.from_texts(texts, embedding=embeddings).as_retriever(\n", + " search_kwargs={\"k\": 10}\n", + ")\n", + "query = \"What can you tell me about the Celtics?\"\n", + "\n", + "# Get relevant documents ordered by relevance score\n", + "docs = retriever.get_relevant_documents(query)\n", + "docs" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "34fb9d6e", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "[Document(page_content='The Celtics are my favourite team.', metadata={}),\n", + " Document(page_content='The Boston Celtics won the game by 20 points', metadata={}),\n", + " Document(page_content='Elden Ring is one of the best games in the last 15 years.', metadata={}),\n", + " Document(page_content='I simply love going to the movies', metadata={}),\n", + " Document(page_content='This is just a random text.', metadata={}),\n", + " Document(page_content='Fly me to the moon is one of my favourite songs.', metadata={}),\n", + " Document(page_content='Basquetball is a great sport.', metadata={}),\n", + " Document(page_content='Larry Bird was an iconic NBA player.', metadata={}),\n", + " Document(page_content='L. Kornet is one of the best Celtics players.', metadata={}),\n", + " Document(page_content='This is a document about the Boston Celtics', metadata={})]" + ] + }, + "execution_count": 3, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# Reorder the documents:\n", + "# Less relevant document will be at the middle of the list and more\n", + "# relevant elements at begining / end.\n", + "reordering = LongContextReorder()\n", + "reordered_docs = reordering.transform_documents(docs)\n", + "\n", + "# Confirm that the 4 relevant documents are at begining and end.\n", + "reordered_docs" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "ceccab87", + "metadata": {}, + "outputs": [], + "source": [ + "# We prepare and run a custom Stuff chain with reordered docs as context.\n", + "\n", + "# Override prompts\n", + "document_prompt = PromptTemplate(\n", + " input_variables=[\"page_content\"], template=\"{page_content}\"\n", + ")\n", + "document_variable_name = \"context\"\n", + "llm = OpenAI()\n", + "stuff_prompt_override = \"\"\"Given this text extracts:\n", + "-----\n", + "{context}\n", + "-----\n", + "Please answer the following question:\n", + "{query}\"\"\"\n", + "prompt = PromptTemplate(\n", + " template=stuff_prompt_override, input_variables=[\"context\", \"query\"]\n", + ")\n", + "\n", + "# Instantiate the chain\n", + "llm_chain = LLMChain(llm=llm, prompt=prompt)\n", + "chain = StuffDocumentsChain(\n", + " llm_chain=llm_chain,\n", + " document_prompt=document_prompt,\n", + " document_variable_name=document_variable_name,\n", + ")\n", + "chain.run(input_documents=reordered_docs, query=query)" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.9.16" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/docs/extras/modules/data_connection/retrievers/integrations/merger_retriever.ipynb b/docs/extras/modules/data_connection/retrievers/integrations/merger_retriever.ipynb index e553442916c..7dfbaaab0c5 100644 --- a/docs/extras/modules/data_connection/retrievers/integrations/merger_retriever.ipynb +++ b/docs/extras/modules/data_connection/retrievers/integrations/merger_retriever.ipynb @@ -137,6 +137,36 @@ " base_compressor=pipeline, base_retriever=lotr\n", ")" ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "id": "8f68956e", + "metadata": {}, + "source": [ + "## Re-order results to avoid performance degradation.\n", + "No matter the architecture of your model, there is a sustancial performance degradation when you include 10+ retrieved documents.\n", + "In brief: When models must access relevant information in the middle of long contexts, then tend to ignore the provided documents.\n", + "See: https://arxiv.org/abs//2307.03172" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "007283f3", + "metadata": {}, + "outputs": [], + "source": [ + "# You can use an additional document transformer to reorder documents after removing redudance.\n", + "from langchain.document_transformers import LongContextReorder\n", + "\n", + "filter = EmbeddingsRedundantFilter(embeddings=filter_embeddings)\n", + "reordering = LongContextReorder()\n", + "pipeline = DocumentCompressorPipeline(transformers=[filter, reordering])\n", + "compression_retriever_reordered = ContextualCompressionRetriever(\n", + " base_compressor=pipeline, base_retriever=lotr\n", + ")" + ] } ], "metadata": { diff --git a/langchain/document_transformers/__init__.py b/langchain/document_transformers/__init__.py index a71d9bd0b17..9f1489219fa 100644 --- a/langchain/document_transformers/__init__.py +++ b/langchain/document_transformers/__init__.py @@ -8,6 +8,7 @@ from langchain.document_transformers.embeddings_redundant_filter import ( EmbeddingsRedundantFilter, get_stateful_documents, ) +from langchain.document_transformers.long_context_reorder import LongContextReorder __all__ = [ "DoctranQATransformer", @@ -16,6 +17,7 @@ __all__ = [ "EmbeddingsClusteringFilter", "EmbeddingsRedundantFilter", "get_stateful_documents", + "LongContextReorder", "OpenAIMetadataTagger", ] diff --git a/langchain/document_transformers/long_context_reorder.py b/langchain/document_transformers/long_context_reorder.py new file mode 100644 index 00000000000..5debbed5ee8 --- /dev/null +++ b/langchain/document_transformers/long_context_reorder.py @@ -0,0 +1,44 @@ +"""Reorder documents""" +from typing import Any, List, Sequence + +from pydantic import BaseModel + +from langchain.schema import BaseDocumentTransformer, Document + + +def _litm_reordering(documents: List[Document]) -> List[Document]: + """Los in the middle reorder: the most relevant will be at the + middle of the list and more relevant elements at beginning / end. + See: https://arxiv.org/abs//2307.03172""" + + documents.reverse() + reordered_result = [] + for i, value in enumerate(documents): + if i % 2 == 1: + reordered_result.append(value) + else: + reordered_result.insert(0, value) + return reordered_result + + +class LongContextReorder(BaseDocumentTransformer, BaseModel): + """Lost in the middle: + Performance degrades when models must access relevant information + in the middle of long contexts. + See: https://arxiv.org/abs//2307.03172""" + + class Config: + """Configuration for this pydantic object.""" + + arbitrary_types_allowed = True + + def transform_documents( + self, documents: Sequence[Document], **kwargs: Any + ) -> Sequence[Document]: + """Reorders documents.""" + return _litm_reordering(list(documents)) + + async def atransform_documents( + self, documents: Sequence[Document], **kwargs: Any + ) -> Sequence[Document]: + raise NotImplementedError diff --git a/tests/integration_tests/test_long_context_reorder.py b/tests/integration_tests/test_long_context_reorder.py new file mode 100644 index 00000000000..4a8a26cd0e4 --- /dev/null +++ b/tests/integration_tests/test_long_context_reorder.py @@ -0,0 +1,35 @@ +"""Integration test for doc reordering.""" +from langchain.document_transformers.long_context_reorder import LongContextReorder +from langchain.embeddings import OpenAIEmbeddings +from langchain.vectorstores import Chroma + + +def test_long_context_reorder() -> None: + """Test Lost in the middle reordering get_relevant_docs.""" + texts = [ + "Basquetball is a great sport.", + "Fly me to the moon is one of my favourite songs.", + "The Celtics are my favourite team.", + "This is a document about the Boston Celtics", + "I simply love going to the movies", + "The Boston Celtics won the game by 20 points", + "This is just a random text.", + "Elden Ring is one of the best games in the last 15 years.", + "L. Kornet is one of the best Celtics players.", + "Larry Bird was an iconic NBA player.", + ] + embeddings = OpenAIEmbeddings() + retriever = Chroma.from_texts(texts, embedding=embeddings).as_retriever( + search_kwargs={"k": 10} + ) + reordering = LongContextReorder() + docs = retriever.get_relevant_documents("Tell me about the Celtics") + actual = reordering.transform_documents(docs) + + # First 2 and Last 2 elements must contain the most relevant + first_and_last = list(actual[:2]) + list(actual[-2:]) + assert len(actual) == 10 + assert texts[2] in [d.page_content for d in first_and_last] + assert texts[3] in [d.page_content for d in first_and_last] + assert texts[5] in [d.page_content for d in first_and_last] + assert texts[8] in [d.page_content for d in first_and_last]