diff --git a/docs/modules/chains/combine_docs.md b/docs/modules/chains/combine_docs.md index ff9e17371f0..4e9348ba7ae 100644 --- a/docs/modules/chains/combine_docs.md +++ b/docs/modules/chains/combine_docs.md @@ -6,7 +6,7 @@ For more information on specific use cases as well as different methods for **fe This documentation now picks up from after you've fetched your documents - now what? How do you pass them to the language model in a format it can understand? -There are a few different methods, or chains, for doing so. LangChain supports three of the more common ones - and +There are a few different methods, or chains, for doing so. LangChain supports four of the more common ones - and we are actively looking to include more, so if you have any ideas please reach out! Note that there is not one best method - the decision of which one to use is often very context specific. In order from simplest to most complex: @@ -39,3 +39,13 @@ asking the LLM to refine the output based on the new document. **Pros:** Can pull in more relevant context, and may be less lossy than `MapReduceDocumentsChain`. **Cons:** Requires many more calls to the LLM than `StuffDocumentsChain`. The calls are also NOT independent, meaning they cannot be parallelized like `MapReduceDocumentsChain`. There are also some potential dependencies on the ordering of the documents. + + +## Map-Rerank +This method involves running an initial prompt on each chunk of data that not only tries to complete a +task but also gives a score for how certain it is in its answer. The responses are then +ranked according to this score, and the answer with the highest score is returned. + +**Pros:** Similar pros to `MapReduceDocumentsChain`, but it requires fewer calls to the LLM. + +**Cons:** Cannot combine information between documents. This means it is most useful when you expect a single, simple answer to exist in a single document. diff --git a/docs/modules/chains/combine_docs_examples/qa_with_sources.ipynb b/docs/modules/chains/combine_docs_examples/qa_with_sources.ipynb index 94741dd87f3..b7a285b33af 100644 --- a/docs/modules/chains/combine_docs_examples/qa_with_sources.ipynb +++ b/docs/modules/chains/combine_docs_examples/qa_with_sources.ipynb @@ -7,7 +7,7 @@ "source": [ "# Question Answering with Sources\n", "\n", - "This notebook walks through how to use LangChain for question answering with sources over a list of documents. It covers three different chain types: `stuff`, `map_reduce`, and `refine`. For a more in depth explanation of what these chain types are, see [here](../combine_docs.md)." + "This notebook walks through how to use LangChain for question answering with sources over a list of documents. It covers four different chain types: `stuff`, `map_reduce`, `refine`, and `map_rerank`. For a more in-depth explanation of what these chain types are, see [here](../combine_docs.md)." ] }, { @@ -259,7 +259,7 @@ "source": [ "**Intermediate Steps**\n", "\n", - "We can also return the intermediate steps for `refine` chains, should we want to inspect them. This is done with the `return_refine_steps` variable." + "We can also return the intermediate steps for `refine` chains, should we want to inspect them. This is done with the `return_intermediate_steps` variable."
] }, { @@ -297,10 +297,87 @@ "chain({\"input_documents\": docs, \"question\": query}, return_only_outputs=True)" ] }, + { + "cell_type": "markdown", + "id": "07ff756e", + "metadata": {}, + "source": [ + "## The `map_rerank` Chain\n", + "\n", + "This section shows results of using the `map_rerank` Chain to do question answering with sources." ] + }, + { + "cell_type": "code", + "execution_count": 10, + "id": "46b52ef9", + "metadata": {}, + "outputs": [], + "source": [ + "chain = load_qa_with_sources_chain(OpenAI(temperature=0), chain_type=\"map_rerank\", metadata_keys=['source'], return_intermediate_steps=True)" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "id": "7ce2da04", + "metadata": {}, + "outputs": [], + "source": [ + "query = \"What did the president say about Justice Breyer\"\n", + "result = chain({\"input_documents\": docs, \"question\": query}, return_only_outputs=True)" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "id": "cbdcd3c5", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "' The President thanked Justice Breyer for his service and honored him for dedicating his life to serve the country.'" + ] + }, + "execution_count": 12, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "result[\"output_text\"]" + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "id": "6f0b3d03", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "[{'answer': ' The President thanked Justice Breyer for his service and honored him for dedicating his life to serve the country.',\n", + " 'score': '100'},\n", + " {'answer': ' This document does not answer the question', 'score': '0'},\n", + " {'answer': ' This document does not answer the question', 'score': '0'},\n", + " {'answer': ' This document does not answer the question', 'score': '0'}]" + ] + }, + "execution_count": 14, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "result[\"intermediate_steps\"]" + ] + }, { "cell_type": "code", "execution_count": null, - "id": "aa2b8db9", + "id": "e66b8160", "metadata": {}, "outputs": [], "source": [] diff --git a/docs/modules/chains/combine_docs_examples/question_answering.ipynb b/docs/modules/chains/combine_docs_examples/question_answering.ipynb index 9aa2cc367a9..0e9191766b0 100644 --- a/docs/modules/chains/combine_docs_examples/question_answering.ipynb +++ b/docs/modules/chains/combine_docs_examples/question_answering.ipynb @@ -7,7 +7,7 @@ "source": [ "# Question Answering\n", "\n", - "This notebook walks through how to use LangChain for question answering over a list of documents. It covers three different types of chaings: `stuff`, `map_reduce`, and `refine`. For a more in depth explanation of what these chain types are, see [here](../combine_docs.md)." + "This notebook walks through how to use LangChain for question answering over a list of documents. It covers four different types of chains: `stuff`, `map_reduce`, `refine`, and `map_rerank`. For a more in-depth explanation of what these chain types are, see [here](../combine_docs.md)."
] }, { @@ -21,7 +21,7 @@ }, { "cell_type": "code", - "execution_count": 1, + "execution_count": 11, "id": "17fcbc0f", "metadata": {}, "outputs": [], @@ -34,7 +34,7 @@ }, { "cell_type": "code", - "execution_count": 2, + "execution_count": 12, "id": "291f0117", "metadata": {}, "outputs": [], @@ -49,7 +49,7 @@ }, { "cell_type": "code", - "execution_count": 3, + "execution_count": 13, "id": "fd9666a9", "metadata": {}, "outputs": [], @@ -59,7 +59,7 @@ }, { "cell_type": "code", - "execution_count": 4, + "execution_count": 14, "id": "d1eaf6e6", "metadata": {}, "outputs": [], @@ -70,7 +70,7 @@ }, { "cell_type": "code", - "execution_count": 5, + "execution_count": 15, "id": "a16e3453", "metadata": {}, "outputs": [], @@ -294,6 +294,94 @@ "source": [ "chain({\"input_documents\": docs, \"question\": query}, return_only_outputs=True)" ] + }, + { + "cell_type": "markdown", + "id": "521a77cb", + "metadata": {}, + "source": [ + "## The `map_rerank` Chain\n", + "\n", + "This section shows results of using the `map_rerank` Chain to do question answering." + ] + }, + { + "cell_type": "code", + "execution_count": 16, + "id": "e2bfe203", + "metadata": {}, + "outputs": [], + "source": [ + "chain = load_qa_chain(OpenAI(temperature=0), chain_type=\"map_rerank\", return_intermediate_steps=True)" + ] + }, + { + "cell_type": "code", + "execution_count": 17, + "id": "5c28880c", + "metadata": {}, + "outputs": [], + "source": [ + "query = \"What did the president say about Justice Breyer\"\n", + "results = chain({\"input_documents\": docs, \"question\": query}, return_only_outputs=True)" + ] + }, + { + "cell_type": "code", + "execution_count": 18, + "id": "80ac2db3", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "' The president thanked Justice Breyer for his service and honored him for dedicating his life to serving the country. '" + ] + }, + "execution_count": 18, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "results[\"output_text\"]" + ] + }, + { + "cell_type": "code", + "execution_count": 19, + "id": "b428fcb9", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "[{'answer': ' The president thanked Justice Breyer for his service and honored him for dedicating his life to serving the country. ',\n", + " 'score': '100'},\n", + " {'answer': \" The president said that Justice Breyer is a former top litigator in private practice, a former federal public defender, and from a family of public school educators and police officers. He also said that since she's been nominated, she's received a broad range of support from the Fraternal Order of Police to former judges appointed by Democrats and Republicans, and that she is a consensus builder.\",\n", + " 'score': '100'},\n", + " {'answer': ' The president did not mention Justice Breyer in this context.',\n", + " 'score': '0'},\n", + " {'answer': ' The president did not mention Justice Breyer in the given context. 
',\n", + " 'score': '0'}]" + ] + }, + "execution_count": 19, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "results[\"intermediate_steps\"]" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "c4f86521", + "metadata": {}, + "outputs": [], + "source": [] } ], "metadata": { diff --git a/langchain/chains/combine_documents/map_rerank.py b/langchain/chains/combine_documents/map_rerank.py new file mode 100644 index 00000000000..d1c8092ca60 --- /dev/null +++ b/langchain/chains/combine_documents/map_rerank.py @@ -0,0 +1,113 @@ +"""Combining documents by mapping a chain over them first, then reranking results.""" + +from __future__ import annotations + +from typing import Any, Dict, List, Optional, Tuple, cast + +from pydantic import BaseModel, Extra, root_validator + +from langchain.chains.combine_documents.base import BaseCombineDocumentsChain +from langchain.chains.llm import LLMChain +from langchain.docstore.document import Document +from langchain.prompts.base import RegexParser + + +class MapRerankDocumentsChain(BaseCombineDocumentsChain, BaseModel): + """Combining documents by mapping a chain over them, then reranking results.""" + + llm_chain: LLMChain + """Chain to apply to each document individually.""" + document_variable_name: str + """The variable name in the llm_chain to put the documents in. + If only one variable in the llm_chain, this need not be provided.""" + rank_key: str + """Key in output of llm_chain to rank on.""" + answer_key: str + """Key in output of llm_chain to return as answer.""" + metadata_keys: Optional[List[str]] = None + return_intermediate_steps: bool = False + + class Config: + """Configuration for this pydantic object.""" + + extra = Extra.forbid + arbitrary_types_allowed = True + + @property + def output_keys(self) -> List[str]: + """Return output keys of this chain.
+ + :meta private: + """ + _output_keys = super().output_keys + if self.return_intermediate_steps: + _output_keys = _output_keys + ["intermediate_steps"] + if self.metadata_keys is not None: + _output_keys += self.metadata_keys + return _output_keys + + @root_validator() + def validate_llm_output(cls, values: Dict) -> Dict: + """Validate that the llm_chain's output parser is a RegexParser with the rank and answer keys.""" + output_parser = values["llm_chain"].prompt.output_parser + if not isinstance(output_parser, RegexParser): + raise ValueError( + "Output parser of llm_chain should be a RegexParser," + f" got {output_parser}" + ) + output_keys = output_parser.output_keys + if values["rank_key"] not in output_keys: + raise ValueError( + f"Got {values['rank_key']} as key to rank on, but did not find " + f"it in the llm_chain output keys ({output_keys})" + ) + if values["answer_key"] not in output_keys: + raise ValueError( + f"Got {values['answer_key']} as key to return, but did not find " + f"it in the llm_chain output keys ({output_keys})" + ) + return values + + @root_validator(pre=True) + def get_default_document_variable_name(cls, values: Dict) -> Dict: + """Get default document variable name, if not provided.""" + if "document_variable_name" not in values: + llm_chain_variables = values["llm_chain"].prompt.input_variables + if len(llm_chain_variables) == 1: + values["document_variable_name"] = llm_chain_variables[0] + else: + raise ValueError( + "document_variable_name must be provided if there are " + "multiple llm_chain input_variables" + ) + else: + llm_chain_variables = values["llm_chain"].prompt.input_variables + if values["document_variable_name"] not in llm_chain_variables: + raise ValueError( + f"document_variable_name {values['document_variable_name']} was " + f"not found in llm_chain input_variables: {llm_chain_variables}" + ) + return values + + def combine_docs(self, docs: List[Document], **kwargs: Any) -> Tuple[str, dict]: + """Combine documents in a map rerank manner. + + Combine by first mapping the chain over all documents, then reranking the results. + """ + results = self.llm_chain.apply_and_parse( + # FYI - this is parallelized and so it is fast.
+ [{**{self.document_variable_name: d.page_content}, **kwargs} for d in docs] + ) + typed_results = cast(List[dict], results) + + sorted_res = sorted( + zip(typed_results, docs), key=lambda x: -int(x[0][self.rank_key]) + ) + output, document = sorted_res[0] + extra_info = {} + if self.metadata_keys is not None: + for key in self.metadata_keys: + extra_info[key] = document.metadata[key] + if self.return_intermediate_steps: + extra_info["intermediate_steps"] = results + return output[self.answer_key], extra_info diff --git a/langchain/chains/qa_with_sources/__init__.py b/langchain/chains/qa_with_sources/__init__.py index b6e7b51f3a4..4af17329798 100644 --- a/langchain/chains/qa_with_sources/__init__.py +++ b/langchain/chains/qa_with_sources/__init__.py @@ -3,6 +3,7 @@ from typing import Any, Mapping, Optional, Protocol from langchain.chains.combine_documents.base import BaseCombineDocumentsChain from langchain.chains.combine_documents.map_reduce import MapReduceDocumentsChain +from langchain.chains.combine_documents.map_rerank import MapRerankDocumentsChain from langchain.chains.combine_documents.refine import RefineDocumentsChain from langchain.chains.combine_documents.stuff import StuffDocumentsChain from langchain.chains.llm import LLMChain @@ -11,6 +12,7 @@ from langchain.chains.qa_with_sources import ( refine_prompts, stuff_prompt, ) +from langchain.chains.question_answering import map_rerank_prompt from langchain.llms.base import BaseLLM from langchain.prompts.base import BasePromptTemplate @@ -22,6 +24,25 @@ class LoadingCallable(Protocol): """Callable to load the combine documents chain.""" +def _load_map_rerank_chain( + llm: BaseLLM, + prompt: BasePromptTemplate = map_rerank_prompt.PROMPT, + verbose: bool = False, + document_variable_name: str = "context", + rank_key: str = "score", + answer_key: str = "answer", + **kwargs: Any, +) -> MapRerankDocumentsChain: + llm_chain = LLMChain(llm=llm, prompt=prompt, verbose=verbose) + return MapRerankDocumentsChain( + llm_chain=llm_chain, + rank_key=rank_key, + answer_key=answer_key, + document_variable_name=document_variable_name, + **kwargs, + ) + + def _load_stuff_chain( llm: BaseLLM, prompt: BasePromptTemplate = stuff_prompt.PROMPT, @@ -137,6 +158,7 @@ def load_qa_with_sources_chain( "stuff": _load_stuff_chain, "map_reduce": _load_map_reduce_chain, "refine": _load_refine_chain, + "map_rerank": _load_map_rerank_chain, } if chain_type not in loader_mapping: raise ValueError( diff --git a/langchain/chains/question_answering/__init__.py b/langchain/chains/question_answering/__init__.py index 1e9bc7cdb80..a2b64445f5e 100644 --- a/langchain/chains/question_answering/__init__.py +++ b/langchain/chains/question_answering/__init__.py @@ -3,11 +3,13 @@ from typing import Any, Mapping, Optional, Protocol from langchain.chains.combine_documents.base import BaseCombineDocumentsChain from langchain.chains.combine_documents.map_reduce import MapReduceDocumentsChain +from langchain.chains.combine_documents.map_rerank import MapRerankDocumentsChain from langchain.chains.combine_documents.refine import RefineDocumentsChain from langchain.chains.combine_documents.stuff import StuffDocumentsChain from langchain.chains.llm import LLMChain from langchain.chains.question_answering import ( map_reduce_prompt, + map_rerank_prompt, refine_prompts, stuff_prompt, ) @@ -22,6 +24,25 @@ class LoadingCallable(Protocol): """Callable to load the combine documents chain.""" +def _load_map_rerank_chain( + llm: BaseLLM, + prompt: BasePromptTemplate = map_rerank_prompt.PROMPT, 
+ verbose: bool = False, + document_variable_name: str = "context", + rank_key: str = "score", + answer_key: str = "answer", + **kwargs: Any, +) -> MapRerankDocumentsChain: + llm_chain = LLMChain(llm=llm, prompt=prompt, verbose=verbose) + return MapRerankDocumentsChain( + llm_chain=llm_chain, + rank_key=rank_key, + answer_key=answer_key, + document_variable_name=document_variable_name, + **kwargs, + ) + + def _load_stuff_chain( llm: BaseLLM, prompt: BasePromptTemplate = stuff_prompt.PROMPT, @@ -132,6 +153,7 @@ def load_qa_chain( "stuff": _load_stuff_chain, "map_reduce": _load_map_reduce_chain, "refine": _load_refine_chain, + "map_rerank": _load_map_rerank_chain, } if chain_type not in loader_mapping: raise ValueError( diff --git a/langchain/chains/question_answering/map_rerank_prompt.py b/langchain/chains/question_answering/map_rerank_prompt.py new file mode 100644 index 00000000000..ab68048b0dd --- /dev/null +++ b/langchain/chains/question_answering/map_rerank_prompt.py @@ -0,0 +1,66 @@ +# flake8: noqa +from langchain.prompts import PromptTemplate +from langchain.prompts.base import RegexParser + +output_parser = RegexParser( + regex=r"(.*?)\nScore: (.*)", + output_keys=["answer", "score"], +) + +prompt_template = """Use the following pieces of context to answer the question at the end. If you don't know the answer, just say that you don't know, don't try to make up an answer. + +In addition to giving an answer, also return a score of how fully it answered the user's question. This should be in the following format: + +Question: [question here] +Helpful Answer: [answer here] +Score: [score between 0 and 100] + +How to determine the score: +- A higher score indicates a better answer +- A better answer responds fully to the question asked, with a sufficient level of detail +- If you do not know the answer based on the context, that should be a score of 0 +- Don't be overconfident! + +Example #1 + +Context: +--------- +Apples are red +--------- +Question: what color are apples? +Helpful Answer: red +Score: 100 + +Example #2 + +Context: +--------- +it was night and the witness forgot his glasses. he was not sure if it was a sports car or an suv +--------- +Question: what type was the car? +Helpful Answer: a sports car or an suv +Score: 60 + +Example #3 + +Context: +--------- +Pears are either red or orange +--------- +Question: what color are apples? +Helpful Answer: This document does not answer the question +Score: 0 + +Begin! + +Context: +--------- +{context} +--------- +Question: {question} +Helpful Answer:""" +PROMPT = PromptTemplate( + template=prompt_template, + input_variables=["context", "question"], + output_parser=output_parser, +)
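For anyone reviewing this branch locally, here is a minimal sketch of how the new `map_rerank` chain type is exercised end to end. It mirrors the notebook cells in this diff; the toy `Document` list and the question are placeholders for whatever retrieval step would normally produce `docs`, and an `OPENAI_API_KEY` environment variable is assumed to be set.

```python
from langchain.chains.question_answering import load_qa_chain
from langchain.docstore.document import Document
from langchain.llms import OpenAI

# Placeholder documents standing in for a real retrieval step.
docs = [
    Document(page_content="The president thanked Justice Breyer for his service."),
    Document(page_content="Apples are red."),
]

# Each document is scored independently with the map_rerank prompt
# (rank_key="score" and answer_key="answer" by default); the answer from
# the highest-scoring document is returned.
chain = load_qa_chain(
    OpenAI(temperature=0),
    chain_type="map_rerank",
    return_intermediate_steps=True,
)
result = chain(
    {"input_documents": docs, "question": "What did the president say about Justice Breyer?"},
    return_only_outputs=True,
)
print(result["output_text"])         # best-scoring answer
print(result["intermediate_steps"])  # per-document {'answer', 'score'} dicts
```

The same pattern should work for `load_qa_with_sources_chain` by additionally passing `metadata_keys=['source']`, as shown in the qa_with_sources notebook above.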