map rerank chain (#516)

Add a chain that applies a prompt to all inputs and then returns not only an answer but also a score for it. Add examples for question answering and question answering with sources.

This commit is contained in:
parent 948eee9fe1
commit 8dfad874a2
@ -6,7 +6,7 @@ For more information on specific use cases as well as different methods for **fe

This documentation now picks up from after you've fetched your documents - now what?
How do you pass them to the language model in a format it can understand?
There are a few different methods, or chains, for doing so. LangChain supports three of the more common ones - and
There are a few different methods, or chains, for doing so. LangChain supports four of the more common ones - and
we are actively looking to include more, so if you have any ideas please reach out! Note that there is not
one best method - the decision of which one to use is often very context specific. In order from simplest to
most complex:

@ -39,3 +39,13 @@ asking the LLM to refine the output based on the new document.

**Pros:** Can pull in more relevant context, and may be less lossy than `MapReduceDocumentsChain`.

**Cons:** Requires many more calls to the LLM than `StuffDocumentsChain`. The calls are also NOT independent, meaning they cannot be parallelized like `MapReduceDocumentsChain`. There are also some potential dependencies on the ordering of the documents.

## Map-Rerank

This method involves running an initial prompt on each chunk of data that not only tries to complete a
task but also gives a score for how certain it is in its answer. The responses are then
ranked according to this score, and the highest-scoring response is returned.

**Pros:** Similar pros to `MapReduceDocumentsChain`. Compared to `MapReduceDocumentsChain`, it requires fewer calls.

**Cons:** Cannot combine information between documents. This means it is most useful when you expect there to be a single simple answer in a single document.
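As a quick illustration of the chain type documented above, here is a minimal sketch of loading it through load_qa_chain, which this diff adds to the question answering loader mapping below; the `docs` and `query` values are placeholders, not part of the change:

# Minimal sketch of using the new map_rerank chain type for question answering.
# `docs` and `query` are illustrative placeholders.
from langchain.chains.question_answering import load_qa_chain
from langchain.docstore.document import Document
from langchain.llms import OpenAI

docs = [Document(page_content="Apples are red", metadata={"source": "fruit-notes"})]
query = "what color are apples?"

chain = load_qa_chain(OpenAI(temperature=0), chain_type="map_rerank")
result = chain({"input_documents": docs, "question": query}, return_only_outputs=True)
print(result["output_text"])  # highest-scoring answer across the documents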
@ -7,7 +7,7 @@
"source": [
"# Question Answering with Sources\n",
"\n",
"This notebook walks through how to use LangChain for question answering with sources over a list of documents. It covers three different chain types: `stuff`, `map_reduce`, and `refine`. For a more in depth explanation of what these chain types are, see [here](../combine_docs.md)."
"This notebook walks through how to use LangChain for question answering with sources over a list of documents. It covers four different chain types: `stuff`, `map_reduce`, `refine`, and `map-rerank`. For a more in depth explanation of what these chain types are, see [here](../combine_docs.md)."
]
},
{
@ -259,7 +259,7 @@
"source": [
"**Intermediate Steps**\n",
"\n",
"We can also return the intermediate steps for `refine` chains, should we want to inspect them. This is done with the `return_refine_steps` variable."
"We can also return the intermediate steps for `refine` chains, should we want to inspect them. This is done with the `return_intermediate_steps` variable."
]
},
{
@ -297,10 +297,87 @@
"chain({\"input_documents\": docs, \"question\": query}, return_only_outputs=True)"
]
},
{
"cell_type": "markdown",
"id": "07ff756e",
"metadata": {},
"source": [
"## The `map-rerank` Chain\n",
"\n",
"This section shows results of using the `map-rerank` chain to do question answering with sources."
]
},
{
"cell_type": "code",
"execution_count": 10,
"id": "46b52ef9",
"metadata": {},
"outputs": [],
"source": [
"chain = load_qa_with_sources_chain(OpenAI(temperature=0), chain_type=\"map_rerank\", metadata_keys=['source'], return_intermediate_steps=True)"
]
},
{
"cell_type": "code",
"execution_count": 11,
"id": "7ce2da04",
"metadata": {},
"outputs": [],
"source": [
"query = \"What did the president say about Justice Breyer\"\n",
"result = chain({\"input_documents\": docs, \"question\": query}, return_only_outputs=True)"
]
},
{
"cell_type": "code",
"execution_count": 12,
"id": "cbdcd3c5",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"' The President thanked Justice Breyer for his service and honored him for dedicating his life to serve the country.'"
]
},
"execution_count": 12,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"result[\"output_text\"]"
]
},
{
"cell_type": "code",
"execution_count": 14,
"id": "6f0b3d03",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"[{'answer': ' The President thanked Justice Breyer for his service and honored him for dedicating his life to serve the country.',\n",
" 'score': '100'},\n",
" {'answer': ' This document does not answer the question', 'score': '0'},\n",
" {'answer': ' This document does not answer the question', 'score': '0'},\n",
" {'answer': ' This document does not answer the question', 'score': '0'}]"
]
},
"execution_count": 14,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"result[\"intermediate_steps\"]"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "aa2b8db9",
"id": "e66b8160",
"metadata": {},
"outputs": [],
"source": []
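Because the chain above was loaded with metadata_keys=['source'], the metadata of the document that produced the top-ranked answer should also come back as an output key (see the output_keys property of MapRerankDocumentsChain later in this diff). A minimal sketch of inspecting the result, assuming the notebook's `result` variable and that metadata keys are surfaced alongside `output_text`:

# Sketch of reading the extra output keys from the map_rerank run above.
print(result["output_text"])   # highest-scoring answer
print(result.get("source"))    # metadata of the document that produced it, if surfaced
for step in result["intermediate_steps"]:
    print(step["score"], step["answer"])  # one parsed answer/score pair per document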
@ -7,7 +7,7 @@
"source": [
"# Question Answering\n",
"\n",
"This notebook walks through how to use LangChain for question answering over a list of documents. It covers three different types of chains: `stuff`, `map_reduce`, and `refine`. For a more in depth explanation of what these chain types are, see [here](../combine_docs.md)."
"This notebook walks through how to use LangChain for question answering over a list of documents. It covers four different types of chains: `stuff`, `map_reduce`, `refine`, and `map-rerank`. For a more in depth explanation of what these chain types are, see [here](../combine_docs.md)."
]
},
{
@ -21,7 +21,7 @@
},
{
"cell_type": "code",
"execution_count": 1,
"execution_count": 11,
"id": "17fcbc0f",
"metadata": {},
"outputs": [],
@ -34,7 +34,7 @@
},
{
"cell_type": "code",
"execution_count": 2,
"execution_count": 12,
"id": "291f0117",
"metadata": {},
"outputs": [],
@ -49,7 +49,7 @@
},
{
"cell_type": "code",
"execution_count": 3,
"execution_count": 13,
"id": "fd9666a9",
"metadata": {},
"outputs": [],
@ -59,7 +59,7 @@
},
{
"cell_type": "code",
"execution_count": 4,
"execution_count": 14,
"id": "d1eaf6e6",
"metadata": {},
"outputs": [],
@ -70,7 +70,7 @@
},
{
"cell_type": "code",
"execution_count": 5,
"execution_count": 15,
"id": "a16e3453",
"metadata": {},
"outputs": [],
@ -294,6 +294,94 @@
"source": [
"chain({\"input_documents\": docs, \"question\": query}, return_only_outputs=True)"
]
},
{
"cell_type": "markdown",
"id": "521a77cb",
"metadata": {},
"source": [
"## The `map-rerank` Chain\n",
"\n",
"This section shows results of using the `map-rerank` chain to do question answering."
]
},
{
"cell_type": "code",
"execution_count": 16,
"id": "e2bfe203",
"metadata": {},
"outputs": [],
"source": [
"chain = load_qa_chain(OpenAI(temperature=0), chain_type=\"map_rerank\", return_intermediate_steps=True)"
]
},
{
"cell_type": "code",
"execution_count": 17,
"id": "5c28880c",
"metadata": {},
"outputs": [],
"source": [
"query = \"What did the president say about Justice Breyer\"\n",
"results = chain({\"input_documents\": docs, \"question\": query}, return_only_outputs=True)"
]
},
{
"cell_type": "code",
"execution_count": 18,
"id": "80ac2db3",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"' The president thanked Justice Breyer for his service and honored him for dedicating his life to serving the country. '"
]
},
"execution_count": 18,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"results[\"output_text\"]"
]
},
{
"cell_type": "code",
"execution_count": 19,
"id": "b428fcb9",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"[{'answer': ' The president thanked Justice Breyer for his service and honored him for dedicating his life to serving the country. ',\n",
" 'score': '100'},\n",
" {'answer': \" The president said that Justice Breyer is a former top litigator in private practice, a former federal public defender, and from a family of public school educators and police officers. He also said that since she's been nominated, she's received a broad range of support from the Fraternal Order of Police to former judges appointed by Democrats and Republicans, and that she is a consensus builder.\",\n",
" 'score': '100'},\n",
" {'answer': ' The president did not mention Justice Breyer in this context.',\n",
" 'score': '0'},\n",
" {'answer': ' The president did not mention Justice Breyer in the given context. ',\n",
" 'score': '0'}]"
]
},
"execution_count": 19,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"results[\"intermediate_steps\"]"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "c4f86521",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
langchain/chains/combine_documents/map_rerank.py (113 lines, new file)
@ -0,0 +1,113 @@
"""Combining documents by mapping a chain over them first, then reranking results."""

from __future__ import annotations

from typing import Any, Dict, List, Optional, Tuple, cast

from pydantic import BaseModel, Extra, root_validator

from langchain.chains.combine_documents.base import BaseCombineDocumentsChain
from langchain.chains.llm import LLMChain
from langchain.docstore.document import Document
from langchain.prompts.base import RegexParser


class MapRerankDocumentsChain(BaseCombineDocumentsChain, BaseModel):
    """Combining documents by mapping a chain over them, then reranking results."""

    llm_chain: LLMChain
    """Chain to apply to each document individually."""
    document_variable_name: str
    """The variable name in the llm_chain to put the documents in.

    If only one variable in the llm_chain, this need not be provided."""
    rank_key: str
    """Key in output of llm_chain to rank on."""
    answer_key: str
    """Key in output of llm_chain to return as answer."""
    metadata_keys: Optional[List[str]] = None
    return_intermediate_steps: bool = False

    class Config:
        """Configuration for this pydantic object."""

        extra = Extra.forbid
        arbitrary_types_allowed = True

    @property
    def output_keys(self) -> List[str]:
        """Expect input key.

        :meta private:
        """
        _output_keys = super().output_keys
        if self.return_intermediate_steps:
            _output_keys = _output_keys + ["intermediate_steps"]
        if self.metadata_keys is not None:
            _output_keys += self.metadata_keys
        return _output_keys

    @root_validator()
    def validate_llm_output(cls, values: Dict) -> Dict:
        """Validate that the combine chain outputs a dictionary."""
        output_parser = values["llm_chain"].prompt.output_parser
        if not isinstance(output_parser, RegexParser):
            raise ValueError(
                "Output parser of llm_chain should be a RegexParser,"
                f" got {output_parser}"
            )
        output_keys = output_parser.output_keys
        if values["rank_key"] not in output_keys:
            raise ValueError(
                f"Got {values['rank_key']} as key to rank on, but did not find "
                f"it in the llm_chain output keys ({output_keys})"
            )
        if values["answer_key"] not in output_keys:
            raise ValueError(
                f"Got {values['answer_key']} as key to return, but did not find "
                f"it in the llm_chain output keys ({output_keys})"
            )
        return values

    @root_validator(pre=True)
    def get_default_document_variable_name(cls, values: Dict) -> Dict:
        """Get default document variable name, if not provided."""
        if "document_variable_name" not in values:
            llm_chain_variables = values["llm_chain"].prompt.input_variables
            if len(llm_chain_variables) == 1:
                values["document_variable_name"] = llm_chain_variables[0]
            else:
                raise ValueError(
                    "document_variable_name must be provided if there are "
                    "multiple llm_chain input_variables"
                )
        else:
            llm_chain_variables = values["llm_chain"].prompt.input_variables
            if values["document_variable_name"] not in llm_chain_variables:
                raise ValueError(
                    f"document_variable_name {values['document_variable_name']} was "
                    f"not found in llm_chain input_variables: {llm_chain_variables}"
                )
        return values

    def combine_docs(self, docs: List[Document], **kwargs: Any) -> Tuple[str, dict]:
        """Combine documents in a map rerank manner.

        Combine by mapping first chain over all documents, then reranking the results.
        """
        results = self.llm_chain.apply_and_parse(
            # FYI - this is parallelized and so it is fast.
            [{**{self.document_variable_name: d.page_content}, **kwargs} for d in docs]
        )
        typed_results = cast(List[dict], results)

        sorted_res = sorted(
            zip(typed_results, docs), key=lambda x: -int(x[0][self.rank_key])
        )
        output, document = sorted_res[0]
        extra_info = {}
        if self.metadata_keys is not None:
            for key in self.metadata_keys:
                extra_info[key] = document.metadata[key]
        if self.return_intermediate_steps:
            extra_info["intermediate_steps"] = results
        return output[self.answer_key], extra_info
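To make the reranking step in combine_docs concrete, here is a small self-contained sketch of the same idea in plain Python; the sample results and metadata below are made up for illustration:

# Minimal sketch of the rerank logic above (illustrative data, not library output).
# Each parsed result is a dict like {"answer": "...", "score": "100"}.
parsed_results = [
    {"answer": "red", "score": "100"},
    {"answer": "This document does not answer the question", "score": "0"},
]
doc_metadata = [{"source": "doc-0"}, {"source": "doc-1"}]

# Sort result/metadata pairs by descending integer score, as the rank_key sort does.
best_result, best_meta = sorted(
    zip(parsed_results, doc_metadata), key=lambda pair: -int(pair[0]["score"])
)[0]
print(best_result["answer"], best_meta["source"])  # -> red doc-0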
@ -3,6 +3,7 @@ from typing import Any, Mapping, Optional, Protocol

from langchain.chains.combine_documents.base import BaseCombineDocumentsChain
from langchain.chains.combine_documents.map_reduce import MapReduceDocumentsChain
from langchain.chains.combine_documents.map_rerank import MapRerankDocumentsChain
from langchain.chains.combine_documents.refine import RefineDocumentsChain
from langchain.chains.combine_documents.stuff import StuffDocumentsChain
from langchain.chains.llm import LLMChain
@ -11,6 +12,7 @@ from langchain.chains.qa_with_sources import (
    refine_prompts,
    stuff_prompt,
)
from langchain.chains.question_answering import map_rerank_prompt
from langchain.llms.base import BaseLLM
from langchain.prompts.base import BasePromptTemplate

@ -22,6 +24,25 @@ class LoadingCallable(Protocol):
    """Callable to load the combine documents chain."""


def _load_map_rerank_chain(
    llm: BaseLLM,
    prompt: BasePromptTemplate = map_rerank_prompt.PROMPT,
    verbose: bool = False,
    document_variable_name: str = "context",
    rank_key: str = "score",
    answer_key: str = "answer",
    **kwargs: Any,
) -> MapRerankDocumentsChain:
    llm_chain = LLMChain(llm=llm, prompt=prompt, verbose=verbose)
    return MapRerankDocumentsChain(
        llm_chain=llm_chain,
        rank_key=rank_key,
        answer_key=answer_key,
        document_variable_name=document_variable_name,
        **kwargs,
    )


def _load_stuff_chain(
    llm: BaseLLM,
    prompt: BasePromptTemplate = stuff_prompt.PROMPT,
@ -137,6 +158,7 @@ def load_qa_with_sources_chain(
        "stuff": _load_stuff_chain,
        "map_reduce": _load_map_reduce_chain,
        "refine": _load_refine_chain,
        "map_rerank": _load_map_rerank_chain,
    }
    if chain_type not in loader_mapping:
        raise ValueError(
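Because validate_llm_output requires the prompt's output parser to be a RegexParser exposing both the rank_key and the answer_key, a custom prompt has to be wired accordingly. Here is a minimal sketch of constructing a MapRerankDocumentsChain by hand under that constraint; the template text is illustrative, not the shipped map_rerank prompt:

# Sketch: wiring a MapRerankDocumentsChain directly with a custom prompt.
from langchain.chains.combine_documents.map_rerank import MapRerankDocumentsChain
from langchain.chains.llm import LLMChain
from langchain.llms import OpenAI
from langchain.prompts import PromptTemplate
from langchain.prompts.base import RegexParser

parser = RegexParser(regex=r"(.*?)\nScore: (.*)", output_keys=["answer", "score"])
prompt = PromptTemplate(
    # Illustrative template; it must still format answers as "... \nScore: N".
    template="{context}\nQuestion: {question}\nHelpful Answer:",
    input_variables=["context", "question"],
    output_parser=parser,
)
chain = MapRerankDocumentsChain(
    llm_chain=LLMChain(llm=OpenAI(temperature=0), prompt=prompt),
    rank_key="score",          # must be one of the parser's output_keys
    answer_key="answer",       # must be one of the parser's output_keys
    document_variable_name="context",
)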
@ -3,11 +3,13 @@ from typing import Any, Mapping, Optional, Protocol

from langchain.chains.combine_documents.base import BaseCombineDocumentsChain
from langchain.chains.combine_documents.map_reduce import MapReduceDocumentsChain
from langchain.chains.combine_documents.map_rerank import MapRerankDocumentsChain
from langchain.chains.combine_documents.refine import RefineDocumentsChain
from langchain.chains.combine_documents.stuff import StuffDocumentsChain
from langchain.chains.llm import LLMChain
from langchain.chains.question_answering import (
    map_reduce_prompt,
    map_rerank_prompt,
    refine_prompts,
    stuff_prompt,
)
@ -22,6 +24,25 @@ class LoadingCallable(Protocol):
    """Callable to load the combine documents chain."""


def _load_map_rerank_chain(
    llm: BaseLLM,
    prompt: BasePromptTemplate = map_rerank_prompt.PROMPT,
    verbose: bool = False,
    document_variable_name: str = "context",
    rank_key: str = "score",
    answer_key: str = "answer",
    **kwargs: Any,
) -> MapRerankDocumentsChain:
    llm_chain = LLMChain(llm=llm, prompt=prompt, verbose=verbose)
    return MapRerankDocumentsChain(
        llm_chain=llm_chain,
        rank_key=rank_key,
        answer_key=answer_key,
        document_variable_name=document_variable_name,
        **kwargs,
    )


def _load_stuff_chain(
    llm: BaseLLM,
    prompt: BasePromptTemplate = stuff_prompt.PROMPT,
@ -132,6 +153,7 @@ def load_qa_chain(
        "stuff": _load_stuff_chain,
        "map_reduce": _load_map_reduce_chain,
        "refine": _load_refine_chain,
        "map_rerank": _load_map_rerank_chain,
    }
    if chain_type not in loader_mapping:
        raise ValueError(
langchain/chains/question_answering/map_rerank_prompt.py (66 lines, new file)
@ -0,0 +1,66 @@
# flake8: noqa
from langchain.prompts import PromptTemplate
from langchain.prompts.base import RegexParser

output_parser = RegexParser(
    regex=r"(.*?)\nScore: (.*)",
    output_keys=["answer", "score"],
)

prompt_template = """Use the following pieces of context to answer the question at the end. If you don't know the answer, just say that you don't know, don't try to make up an answer.

In addition to giving an answer, also return a score of how fully it answered the user's question. This should be in the following format:

Question: [question here]
Helpful Answer: [answer here]
Score: [score between 0 and 100]

How to determine the score:
- Higher is a better answer
- Better responds fully to the asked question, with sufficient level of detail
- If you do not know the answer based on the context, that should be a score of 0
- Don't be overconfident!

Example #1

Context:
---------
Apples are red
---------
Question: what color are apples?
Helpful Answer: red
Score: 100

Example #2

Context:
---------
it was night and the witness forgot his glasses. he was not sure if it was a sports car or an suv
---------
Question: what type was the car?
Helpful Answer: a sports car or an suv
Score: 60

Example #3

Context:
---------
Pears are either red or orange
---------
Question: what color are apples?
Helpful Answer: This document does not answer the question
Score: 0

Begin!

Context:
---------
{context}
---------
Question: {question}
Helpful Answer:"""
PROMPT = PromptTemplate(
    template=prompt_template,
    input_variables=["context", "question"],
    output_parser=output_parser,
)
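For intuition, here is a small sketch of what the RegexParser pattern above extracts from a completion that follows the requested format; the completion string is made up for illustration:

import re

# The same pattern map_rerank_prompt.py gives to RegexParser.
pattern = r"(.*?)\nScore: (.*)"

# A made-up completion in the "answer ... Score: N" format the prompt requests.
completion = " The President thanked Justice Breyer for his service.\nScore: 100"

match = re.search(pattern, completion)
parsed = {"answer": match.group(1), "score": match.group(2)}
print(parsed)  # {'answer': ' The President thanked Justice Breyer for his service.', 'score': '100'}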