add documentation on how to load different chain types (#595)

parent 956416c150
commit d574bf0a27
@@ -46,7 +46,7 @@
 "metadata": {},
 "outputs": [],
 "source": [
-"qa = VectorDBQA.from_llm(llm=OpenAI(), vectorstore=docsearch)"
+"qa = VectorDBQA.from_chain_type(llm=OpenAI(), chain_type=\"stuff\", vectorstore=docsearch)"
 ]
 },
 {
@@ -58,7 +58,7 @@
 {
 "data": {
 "text/plain": [
-"' The president said that Ketanji Brown Jackson is one of the nation’s top legal minds and that she will continue Justice Breyer’s legacy of excellence.'"
+"\" The president said that Ketanji Brown Jackson is one of the nation's top legal minds, a former top litigator in private practice, a former federal public defender, from a family of public school educators and police officers, a consensus builder, and has received a broad range of support from the Fraternal Order of Police to former judges appointed by Democrats and Republicans.\""
 ]
 },
 "execution_count": 4,
@@ -71,6 +71,91 @@
 "qa.run(query)"
 ]
 },
+{
+"cell_type": "markdown",
+"id": "c28f1f64",
+"metadata": {},
+"source": [
+"## Chain Type\n",
+"You can easily specify different chain types to load and use in the VectorDBQA chain. For a more detailed walkthrough of these types, please see [this notebook](question_answering.ipynb).\n",
+"\n",
"There are two ways to load different chain types. First, you can specify the chain type argument in the `from_chain_type` method. This allows you to pass in the name of the chain type you want to use. For example, in the below we change the chain type to `map_reduce`."
+]
+},
+{
+"cell_type": "code",
+"execution_count": 5,
+"id": "22d2417d",
+"metadata": {},
+"outputs": [],
+"source": [
+"qa = VectorDBQA.from_chain_type(llm=OpenAI(), chain_type=\"map_reduce\", vectorstore=docsearch)"
+]
+},
+{
+"cell_type": "code",
+"execution_count": 6,
+"id": "43204ad1",
+"metadata": {},
+"outputs": [
+{
+"data": {
+"text/plain": [
+"\" The president said that Ketanji Brown Jackson is one of the nation's top legal minds, a former top litigator in private practice, a former federal public defender, from a family of public school educators and police officers, a consensus builder, and has received a broad range of support from the Fraternal Order of Police to former judges appointed by Democrats and Republicans.\""
+]
+},
+"execution_count": 6,
+"metadata": {},
+"output_type": "execute_result"
+}
+],
+"source": [
+"query = \"What did the president say about Ketanji Brown Jackson\"\n",
+"qa.run(query)"
+]
+},
+{
+"cell_type": "markdown",
+"id": "60368f38",
+"metadata": {},
+"source": [
"The above way allows you to really simply change the chain_type, but it does provide a ton of flexibility over parameters to that chain type. If you want to control those parameters, you can load the chain directly (as you did in [this notebook](question_answering.ipynb)) and then pass that directly to the the VectorDBQA chain with the `combine_documents_chain` parameter. For example:"
+]
+},
+{
+"cell_type": "code",
+"execution_count": 18,
+"id": "7b403f0d",
+"metadata": {},
+"outputs": [],
+"source": [
+"from langchain.chains.question_answering import load_qa_chain\n",
+"qa_chain = load_qa_chain(OpenAI(temperature=0), chain_type=\"stuff\")\n",
+"qa = VectorDBQA(combine_documents_chain=qa_chain, vectorstore=docsearch)"
+]
+},
+{
+"cell_type": "code",
+"execution_count": 19,
+"id": "9e04a9ac",
+"metadata": {},
+"outputs": [
+{
+"data": {
+"text/plain": [
+"\" The president said that Ketanji Brown Jackson is one of the nation's top legal minds, a former top litigator in private practice, a former federal public defender, and from a family of public school educators and police officers. He also said that she is a consensus builder and has received a broad range of support from the Fraternal Order of Police to former judges appointed by Democrats and Republicans.\""
+]
+},
+"execution_count": 19,
+"metadata": {},
+"output_type": "execute_result"
+}
+],
+"source": [
+"query = \"What did the president say about Ketanji Brown Jackson\"\n",
+"qa.run(query)"
+]
+},
 {
 "cell_type": "markdown",
 "id": "0b8c37f7",
@@ -87,7 +172,7 @@
 "metadata": {},
 "outputs": [],
 "source": [
-"qa = VectorDBQA.from_llm(llm=OpenAI(), vectorstore=docsearch, return_source_documents=True)"
+"qa = VectorDBQA.from_chain_type(llm=OpenAI(), chain_type=\"stuff\", return_source_documents=True)"
 ]
 },
 {
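Putting the two patterns from this notebook together: a minimal sketch, assuming `docsearch` is an existing vectorstore and that `OpenAI` and `VectorDBQA` are importable as below in this era of the library (those imports are assumptions, not shown verbatim in the diff).

# Sketch: the two ways to pick a chain type for VectorDBQA.
from langchain import OpenAI
from langchain.chains import VectorDBQA
from langchain.chains.question_answering import load_qa_chain

# 1) Name the chain type and let from_chain_type load it with default parameters.
qa = VectorDBQA.from_chain_type(llm=OpenAI(), chain_type="map_reduce", vectorstore=docsearch)

# 2) Load the combine-documents chain yourself for full control over its parameters,
#    then pass it in directly.
qa_chain = load_qa_chain(OpenAI(temperature=0), chain_type="stuff")
qa = VectorDBQA(combine_documents_chain=qa_chain, vectorstore=docsearch)

qa.run("What did the president say about Ketanji Brown Jackson")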

@@ -26,7 +26,7 @@
 },
 {
 "cell_type": "code",
-"execution_count": 3,
+"execution_count": 2,
 "id": "17d1306e",
 "metadata": {},
 "outputs": [],
@@ -41,7 +41,7 @@
 },
 {
 "cell_type": "code",
-"execution_count": 4,
+"execution_count": 3,
 "id": "0e745d99",
 "metadata": {},
 "outputs": [],
@@ -51,7 +51,7 @@
 },
 {
 "cell_type": "code",
-"execution_count": 5,
+"execution_count": 4,
 "id": "f42d79dc",
 "metadata": {},
 "outputs": [],
@@ -63,7 +63,7 @@
 },
 {
 "cell_type": "code",
-"execution_count": 6,
+"execution_count": 5,
 "id": "8aa571ae",
 "metadata": {},
 "outputs": [],
@@ -73,26 +73,69 @@
 },
 {
 "cell_type": "code",
-"execution_count": 8,
+"execution_count": 6,
 "id": "aa859d4c",
 "metadata": {},
 "outputs": [],
 "source": [
 "from langchain import OpenAI\n",
 "\n",
-"chain = VectorDBQAWithSourcesChain.from_llm(OpenAI(temperature=0), vectorstore=docsearch)"
+"chain = VectorDBQAWithSourcesChain.from_chain_type(OpenAI(temperature=0), chain_type=\"stuff\", vectorstore=docsearch)"
 ]
 },
 {
 "cell_type": "code",
-"execution_count": 9,
+"execution_count": 7,
 "id": "8ba36fa7",
 "metadata": {},
 "outputs": [
 {
 "data": {
 "text/plain": [
-"{'answer': ' The president thanked Justice Breyer for his service.',\n",
+"{'answer': ' The president thanked Justice Breyer for his service.\\n',\n",
 " 'sources': '30-pl'}"
 ]
 },
+"execution_count": 7,
+"metadata": {},
+"output_type": "execute_result"
+}
+],
+"source": [
+"chain({\"question\": \"What did the president say about Justice Breyer\"}, return_only_outputs=True)"
+]
+},
+{
+"cell_type": "markdown",
+"id": "718ecbda",
+"metadata": {},
+"source": [
+"## Chain Type\n",
"You can easily specify different chain types to load and use in the VectorDBQAWithSourcesChain chain. For a more detailed walkthrough of these types, please see [this notebook](qa_with_sources.ipynb).\n",
+"\n",
"There are two ways to load different chain types. First, you can specify the chain type argument in the `from_chain_type` method. This allows you to pass in the name of the chain type you want to use. For example, in the below we change the chain type to `map_reduce`."
+]
+},
+{
+"cell_type": "code",
+"execution_count": 8,
+"id": "8b35b30a",
+"metadata": {},
+"outputs": [],
+"source": [
+"chain = VectorDBQAWithSourcesChain.from_chain_type(OpenAI(temperature=0), chain_type=\"map_reduce\", vectorstore=docsearch)"
+]
+},
+{
+"cell_type": "code",
+"execution_count": 9,
+"id": "58bd424f",
+"metadata": {},
+"outputs": [
+{
+"data": {
+"text/plain": [
+"{'answer': ' The president honored Justice Stephen Breyer for his service.\\n',\n",
+" 'sources': '30-pl'}"
+]
+},
@@ -104,11 +147,53 @@
 "source": [
 "chain({\"question\": \"What did the president say about Justice Breyer\"}, return_only_outputs=True)"
 ]
 },
+{
+"cell_type": "markdown",
+"id": "21e14eed",
+"metadata": {},
+"source": [
"The above way allows you to really simply change the chain_type, but it does provide a ton of flexibility over parameters to that chain type. If you want to control those parameters, you can load the chain directly (as you did in [this notebook](qa_with_sources.ipynb)) and then pass that directly to the the VectorDBQA chain with the `combine_documents_chain` parameter. For example:"
+]
+},
+{
+"cell_type": "code",
+"execution_count": 12,
+"id": "af35f0c6",
+"metadata": {},
+"outputs": [],
+"source": [
+"from langchain.chains.qa_with_sources import load_qa_with_sources_chain\n",
+"qa_chain = load_qa_with_sources_chain(OpenAI(temperature=0), chain_type=\"stuff\")\n",
+"qa = VectorDBQAWithSourcesChain(combine_document_chain=qa_chain, vectorstore=docsearch)"
+]
+},
+{
+"cell_type": "code",
+"execution_count": 11,
+"id": "c91fdc8a",
+"metadata": {},
+"outputs": [
+{
+"data": {
+"text/plain": [
+"{'answer': ' The president honored Justice Stephen Breyer for his service.\\n',\n",
+" 'sources': '30-pl'}"
+]
+},
+"execution_count": 11,
+"metadata": {},
+"output_type": "execute_result"
+}
+],
+"source": [
+"chain({\"question\": \"What did the president say about Justice Breyer\"}, return_only_outputs=True)"
+]
+}
 ],
 "metadata": {
 "kernelspec": {
-"display_name": "Python 3.9.0 64-bit ('llm-env')",
+"display_name": "Python 3 (ipykernel)",
 "language": "python",
 "name": "python3"
 },

@@ -122,7 +207,7 @@
 "name": "python",
 "nbconvert_exporter": "python",
 "pygments_lexer": "ipython3",
-"version": "3.9.0"
+"version": "3.10.9"
 },
 "vscode": {
 "interpreter": {

@@ -1,169 +1,4 @@
 """Load question answering with sources chains."""
-from typing import Any, Mapping, Optional, Protocol
+from langchain.chains.qa_with_sources.loading import load_qa_with_sources_chain

-from langchain.chains.combine_documents.base import BaseCombineDocumentsChain
-from langchain.chains.combine_documents.map_reduce import MapReduceDocumentsChain
-from langchain.chains.combine_documents.map_rerank import MapRerankDocumentsChain
-from langchain.chains.combine_documents.refine import RefineDocumentsChain
-from langchain.chains.combine_documents.stuff import StuffDocumentsChain
-from langchain.chains.llm import LLMChain
-from langchain.chains.qa_with_sources import (
-    map_reduce_prompt,
-    refine_prompts,
-    stuff_prompt,
-)
-from langchain.chains.question_answering import map_rerank_prompt
-from langchain.llms.base import BaseLLM
-from langchain.prompts.base import BasePromptTemplate
-
-
-class LoadingCallable(Protocol):
-    """Interface for loading the combine documents chain."""
-
-    def __call__(self, llm: BaseLLM, **kwargs: Any) -> BaseCombineDocumentsChain:
-        """Callable to load the combine documents chain."""
-
-
-def _load_map_rerank_chain(
-    llm: BaseLLM,
-    prompt: BasePromptTemplate = map_rerank_prompt.PROMPT,
-    verbose: bool = False,
-    document_variable_name: str = "context",
-    rank_key: str = "score",
-    answer_key: str = "answer",
-    **kwargs: Any,
-) -> MapRerankDocumentsChain:
-    llm_chain = LLMChain(llm=llm, prompt=prompt, verbose=verbose)
-    return MapRerankDocumentsChain(
-        llm_chain=llm_chain,
-        rank_key=rank_key,
-        answer_key=answer_key,
-        document_variable_name=document_variable_name,
-        **kwargs,
-    )
-
-
-def _load_stuff_chain(
-    llm: BaseLLM,
-    prompt: BasePromptTemplate = stuff_prompt.PROMPT,
-    document_prompt: BasePromptTemplate = stuff_prompt.EXAMPLE_PROMPT,
-    document_variable_name: str = "summaries",
-    verbose: Optional[bool] = None,
-    **kwargs: Any,
-) -> StuffDocumentsChain:
-    llm_chain = LLMChain(llm=llm, prompt=prompt, verbose=verbose)
-    return StuffDocumentsChain(
-        llm_chain=llm_chain,
-        document_variable_name=document_variable_name,
-        document_prompt=document_prompt,
-        verbose=verbose,
-        **kwargs,
-    )
-
-
-def _load_map_reduce_chain(
-    llm: BaseLLM,
-    question_prompt: BasePromptTemplate = map_reduce_prompt.QUESTION_PROMPT,
-    combine_prompt: BasePromptTemplate = map_reduce_prompt.COMBINE_PROMPT,
-    document_prompt: BasePromptTemplate = map_reduce_prompt.EXAMPLE_PROMPT,
-    combine_document_variable_name: str = "summaries",
-    map_reduce_document_variable_name: str = "context",
-    collapse_prompt: Optional[BasePromptTemplate] = None,
-    reduce_llm: Optional[BaseLLM] = None,
-    collapse_llm: Optional[BaseLLM] = None,
-    verbose: Optional[bool] = None,
-    **kwargs: Any,
-) -> MapReduceDocumentsChain:
-    map_chain = LLMChain(llm=llm, prompt=question_prompt, verbose=verbose)
-    _reduce_llm = reduce_llm or llm
-    reduce_chain = LLMChain(llm=_reduce_llm, prompt=combine_prompt, verbose=verbose)
-    combine_document_chain = StuffDocumentsChain(
-        llm_chain=reduce_chain,
-        document_variable_name=combine_document_variable_name,
-        document_prompt=document_prompt,
-        verbose=verbose,
-    )
-    if collapse_prompt is None:
-        collapse_chain = None
-        if collapse_llm is not None:
-            raise ValueError(
-                "collapse_llm provided, but collapse_prompt was not: please "
-                "provide one or stop providing collapse_llm."
-            )
-    else:
-        _collapse_llm = collapse_llm or llm
-        collapse_chain = StuffDocumentsChain(
-            llm_chain=LLMChain(
-                llm=_collapse_llm,
-                prompt=collapse_prompt,
-                verbose=verbose,
-            ),
-            document_variable_name=combine_document_variable_name,
-            document_prompt=document_prompt,
-        )
-    return MapReduceDocumentsChain(
-        llm_chain=map_chain,
-        combine_document_chain=combine_document_chain,
-        document_variable_name=map_reduce_document_variable_name,
-        collapse_document_chain=collapse_chain,
-        verbose=verbose,
-        **kwargs,
-    )
-
-
-def _load_refine_chain(
-    llm: BaseLLM,
-    question_prompt: BasePromptTemplate = refine_prompts.DEFAULT_TEXT_QA_PROMPT,
-    refine_prompt: BasePromptTemplate = refine_prompts.DEFAULT_REFINE_PROMPT,
-    document_prompt: BasePromptTemplate = refine_prompts.EXAMPLE_PROMPT,
-    document_variable_name: str = "context_str",
-    initial_response_name: str = "existing_answer",
-    refine_llm: Optional[BaseLLM] = None,
-    verbose: Optional[bool] = None,
-    **kwargs: Any,
-) -> RefineDocumentsChain:
-    initial_chain = LLMChain(llm=llm, prompt=question_prompt, verbose=verbose)
-    _refine_llm = refine_llm or llm
-    refine_chain = LLMChain(llm=_refine_llm, prompt=refine_prompt, verbose=verbose)
-    return RefineDocumentsChain(
-        initial_llm_chain=initial_chain,
-        refine_llm_chain=refine_chain,
-        document_variable_name=document_variable_name,
-        initial_response_name=initial_response_name,
-        document_prompt=document_prompt,
-        verbose=verbose,
-        **kwargs,
-    )
-
-
-def load_qa_with_sources_chain(
-    llm: BaseLLM,
-    chain_type: str = "stuff",
-    verbose: Optional[bool] = None,
-    **kwargs: Any,
-) -> BaseCombineDocumentsChain:
-    """Load question answering with sources chain.
-
-    Args:
-        llm: Language Model to use in the chain.
-        chain_type: Type of document combining chain to use. Should be one of "stuff",
-            "map_reduce", "refine", and "map_rerank".
-        verbose: Whether chains should be run in verbose mode or not. Note that this
-            applies to all chains that make up the final chain.
-
-    Returns:
-        A chain to use for question answering with sources.
-    """
-    loader_mapping: Mapping[str, LoadingCallable] = {
-        "stuff": _load_stuff_chain,
-        "map_reduce": _load_map_reduce_chain,
-        "refine": _load_refine_chain,
-        "map_rerank": _load_map_rerank_chain,
-    }
-    if chain_type not in loader_mapping:
-        raise ValueError(
-            f"Got unsupported chain type: {chain_type}. "
-            f"Should be one of {loader_mapping.keys()}"
-        )
-    _func: LoadingCallable = loader_mapping[chain_type]
-    return _func(llm, verbose=verbose, **kwargs)
+__all__ = ["load_qa_with_sources_chain"]

@@ -12,6 +12,7 @@ from langchain.chains.combine_documents.base import BaseCombineDocumentsChain
 from langchain.chains.combine_documents.map_reduce import MapReduceDocumentsChain
 from langchain.chains.combine_documents.stuff import StuffDocumentsChain
 from langchain.chains.llm import LLMChain
+from langchain.chains.qa_with_sources.loading import load_qa_with_sources_chain
 from langchain.chains.qa_with_sources.map_reduce_prompt import (
     COMBINE_PROMPT,
     EXAMPLE_PROMPT,
@@ -25,7 +26,7 @@ from langchain.prompts.base import BasePromptTemplate
 class BaseQAWithSourcesChain(Chain, BaseModel, ABC):
     """Question answering with sources over documents."""

-    combine_document_chain: BaseCombineDocumentsChain
+    combine_documents_chain: BaseCombineDocumentsChain
     """Chain to use to combine documents."""
     question_key: str = "question"  #: :meta private:
     input_docs_key: str = "docs"  #: :meta private:
@@ -55,10 +56,18 @@ class BaseQAWithSourcesChain(Chain, BaseModel, ABC):
             document_variable_name="context",
         )
         return cls(
-            combine_document_chain=combine_document_chain,
+            combine_documents_chain=combine_document_chain,
             **kwargs,
         )

+    @classmethod
+    def from_chain_type(
+        cls, llm: BaseLLM, chain_type: str = "stuff", **kwargs: Any
+    ) -> BaseQAWithSourcesChain:
+        """Load chain from chain type."""
+        combine_document_chain = load_qa_with_sources_chain(llm, chain_type=chain_type)
+        return cls(combine_documents_chain=combine_document_chain, **kwargs)
+
     class Config:
         """Configuration for this pydantic object."""
@@ -81,11 +90,11 @@ class BaseQAWithSourcesChain(Chain, BaseModel, ABC):
         """
         return [self.answer_key, self.sources_answer_key]

-    @root_validator()
-    def validate_combine_chain_can_be_constructed(cls, values: Dict) -> Dict:
-        """Validate that the combine chain can be constructed."""
-        # Try to construct the combine documents chains.
-
+    @root_validator(pre=True)
+    def validate_naming(cls, values: Dict) -> Dict:
+        """Fix backwards compatibility in naming."""
+        if "combine_document_chain" in values:
+            values["combine_documents_chain"] = values.pop("combine_document_chain")
         return values

     @abstractmethod
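The pre=True root validator above renames the legacy combine_document_chain keyword to the new combine_documents_chain field before pydantic validation runs, so old call sites keep working. A standalone sketch of the same pattern on a toy model (hypothetical, pydantic v1 API):

from typing import Dict

from pydantic import BaseModel, root_validator


class RenamingModel(BaseModel):
    """Toy model mirroring the backwards-compatibility validator above."""

    combine_documents_chain: str

    @root_validator(pre=True)
    def validate_naming(cls, values: Dict) -> Dict:
        # Move the value from the old keyword to the new field name.
        if "combine_document_chain" in values:
            values["combine_documents_chain"] = values.pop("combine_document_chain")
        return values


# Both spellings construct the same model.
assert RenamingModel(combine_document_chain="qa").combine_documents_chain == "qa"
assert RenamingModel(combine_documents_chain="qa").combine_documents_chain == "qa"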
@@ -94,7 +103,7 @@ class BaseQAWithSourcesChain(Chain, BaseModel, ABC):

     def _call(self, inputs: Dict[str, Any]) -> Dict[str, str]:
         docs = self._get_docs(inputs)
-        answer, _ = self.combine_document_chain.combine_docs(docs, **inputs)
+        answer, _ = self.combine_documents_chain.combine_docs(docs, **inputs)
         if "SOURCES: " in answer:
             answer, sources = answer.split("SOURCES: ")
         else:
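_call then separates the sources from the answer by splitting on the literal "SOURCES: " marker that the qa-with-sources prompts ask the model to emit. A small sketch of that parsing step on a hypothetical completion:

# Hypothetical raw completion in the format the prompts request.
answer = " The president thanked Justice Breyer for his service.\nSOURCES: 30-pl"

if "SOURCES: " in answer:
    answer, sources = answer.split("SOURCES: ")
else:
    sources = ""

print({"answer": answer, "sources": sources})
# -> {'answer': ' The president thanked Justice Breyer for his service.\n', 'sources': '30-pl'}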

langchain/chains/qa_with_sources/loading.py (new file, 169 lines)
@@ -0,0 +1,169 @@
"""Load question answering with sources chains."""
|
||||
from typing import Any, Mapping, Optional, Protocol
|
||||
|
||||
from langchain.chains.combine_documents.base import BaseCombineDocumentsChain
|
||||
from langchain.chains.combine_documents.map_reduce import MapReduceDocumentsChain
|
||||
from langchain.chains.combine_documents.map_rerank import MapRerankDocumentsChain
|
||||
from langchain.chains.combine_documents.refine import RefineDocumentsChain
|
||||
from langchain.chains.combine_documents.stuff import StuffDocumentsChain
|
||||
from langchain.chains.llm import LLMChain
|
||||
from langchain.chains.qa_with_sources import (
|
||||
map_reduce_prompt,
|
||||
refine_prompts,
|
||||
stuff_prompt,
|
||||
)
|
||||
from langchain.chains.question_answering import map_rerank_prompt
|
||||
from langchain.llms.base import BaseLLM
|
||||
from langchain.prompts.base import BasePromptTemplate
|
||||
|
||||
|
||||
class LoadingCallable(Protocol):
|
||||
"""Interface for loading the combine documents chain."""
|
||||
|
||||
def __call__(self, llm: BaseLLM, **kwargs: Any) -> BaseCombineDocumentsChain:
|
||||
"""Callable to load the combine documents chain."""
|
||||
|
||||
|
||||
def _load_map_rerank_chain(
|
||||
llm: BaseLLM,
|
||||
prompt: BasePromptTemplate = map_rerank_prompt.PROMPT,
|
||||
verbose: bool = False,
|
||||
document_variable_name: str = "context",
|
||||
rank_key: str = "score",
|
||||
answer_key: str = "answer",
|
||||
**kwargs: Any,
|
||||
) -> MapRerankDocumentsChain:
|
||||
llm_chain = LLMChain(llm=llm, prompt=prompt, verbose=verbose)
|
||||
return MapRerankDocumentsChain(
|
||||
llm_chain=llm_chain,
|
||||
rank_key=rank_key,
|
||||
answer_key=answer_key,
|
||||
document_variable_name=document_variable_name,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
|
||||
def _load_stuff_chain(
|
||||
llm: BaseLLM,
|
||||
prompt: BasePromptTemplate = stuff_prompt.PROMPT,
|
||||
document_prompt: BasePromptTemplate = stuff_prompt.EXAMPLE_PROMPT,
|
||||
document_variable_name: str = "summaries",
|
||||
verbose: Optional[bool] = None,
|
||||
**kwargs: Any,
|
||||
) -> StuffDocumentsChain:
|
||||
llm_chain = LLMChain(llm=llm, prompt=prompt, verbose=verbose)
|
||||
return StuffDocumentsChain(
|
||||
llm_chain=llm_chain,
|
||||
document_variable_name=document_variable_name,
|
||||
document_prompt=document_prompt,
|
||||
verbose=verbose,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
|
||||
def _load_map_reduce_chain(
|
||||
llm: BaseLLM,
|
||||
question_prompt: BasePromptTemplate = map_reduce_prompt.QUESTION_PROMPT,
|
||||
combine_prompt: BasePromptTemplate = map_reduce_prompt.COMBINE_PROMPT,
|
||||
document_prompt: BasePromptTemplate = map_reduce_prompt.EXAMPLE_PROMPT,
|
||||
combine_document_variable_name: str = "summaries",
|
||||
map_reduce_document_variable_name: str = "context",
|
||||
collapse_prompt: Optional[BasePromptTemplate] = None,
|
||||
reduce_llm: Optional[BaseLLM] = None,
|
||||
collapse_llm: Optional[BaseLLM] = None,
|
||||
verbose: Optional[bool] = None,
|
||||
**kwargs: Any,
|
||||
) -> MapReduceDocumentsChain:
|
||||
map_chain = LLMChain(llm=llm, prompt=question_prompt, verbose=verbose)
|
||||
_reduce_llm = reduce_llm or llm
|
||||
reduce_chain = LLMChain(llm=_reduce_llm, prompt=combine_prompt, verbose=verbose)
|
||||
combine_document_chain = StuffDocumentsChain(
|
||||
llm_chain=reduce_chain,
|
||||
document_variable_name=combine_document_variable_name,
|
||||
document_prompt=document_prompt,
|
||||
verbose=verbose,
|
||||
)
|
||||
if collapse_prompt is None:
|
||||
collapse_chain = None
|
||||
if collapse_llm is not None:
|
||||
raise ValueError(
|
||||
"collapse_llm provided, but collapse_prompt was not: please "
|
||||
"provide one or stop providing collapse_llm."
|
||||
)
|
||||
else:
|
||||
_collapse_llm = collapse_llm or llm
|
||||
collapse_chain = StuffDocumentsChain(
|
||||
llm_chain=LLMChain(
|
||||
llm=_collapse_llm,
|
||||
prompt=collapse_prompt,
|
||||
verbose=verbose,
|
||||
),
|
||||
document_variable_name=combine_document_variable_name,
|
||||
document_prompt=document_prompt,
|
||||
)
|
||||
return MapReduceDocumentsChain(
|
||||
llm_chain=map_chain,
|
||||
combine_document_chain=combine_document_chain,
|
||||
document_variable_name=map_reduce_document_variable_name,
|
||||
collapse_document_chain=collapse_chain,
|
||||
verbose=verbose,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
|
||||
def _load_refine_chain(
|
||||
llm: BaseLLM,
|
||||
question_prompt: BasePromptTemplate = refine_prompts.DEFAULT_TEXT_QA_PROMPT,
|
||||
refine_prompt: BasePromptTemplate = refine_prompts.DEFAULT_REFINE_PROMPT,
|
||||
document_prompt: BasePromptTemplate = refine_prompts.EXAMPLE_PROMPT,
|
||||
document_variable_name: str = "context_str",
|
||||
initial_response_name: str = "existing_answer",
|
||||
refine_llm: Optional[BaseLLM] = None,
|
||||
verbose: Optional[bool] = None,
|
||||
**kwargs: Any,
|
||||
) -> RefineDocumentsChain:
|
||||
initial_chain = LLMChain(llm=llm, prompt=question_prompt, verbose=verbose)
|
||||
_refine_llm = refine_llm or llm
|
||||
refine_chain = LLMChain(llm=_refine_llm, prompt=refine_prompt, verbose=verbose)
|
||||
return RefineDocumentsChain(
|
||||
initial_llm_chain=initial_chain,
|
||||
refine_llm_chain=refine_chain,
|
||||
document_variable_name=document_variable_name,
|
||||
initial_response_name=initial_response_name,
|
||||
document_prompt=document_prompt,
|
||||
verbose=verbose,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
|
||||
def load_qa_with_sources_chain(
|
||||
llm: BaseLLM,
|
||||
chain_type: str = "stuff",
|
||||
verbose: Optional[bool] = None,
|
||||
**kwargs: Any,
|
||||
) -> BaseCombineDocumentsChain:
|
||||
"""Load question answering with sources chain.
|
||||
|
||||
Args:
|
||||
llm: Language Model to use in the chain.
|
+        chain_type: Type of document combining chain to use. Should be one of "stuff",
+            "map_reduce", "refine", and "map_rerank".
+        verbose: Whether chains should be run in verbose mode or not. Note that this
+            applies to all chains that make up the final chain.
+
+    Returns:
+        A chain to use for question answering with sources.
+    """
+    loader_mapping: Mapping[str, LoadingCallable] = {
+        "stuff": _load_stuff_chain,
+        "map_reduce": _load_map_reduce_chain,
+        "refine": _load_refine_chain,
+        "map_rerank": _load_map_rerank_chain,
+    }
+    if chain_type not in loader_mapping:
+        raise ValueError(
+            f"Got unsupported chain type: {chain_type}. "
+            f"Should be one of {loader_mapping.keys()}"
+        )
+    _func: LoadingCallable = loader_mapping[chain_type]
+    return _func(llm, verbose=verbose, **kwargs)
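Because load_qa_with_sources_chain forwards **kwargs to the per-type loader, chain-specific parameters can be overridden at load time. A hedged sketch (the custom prompt is illustrative, not from this commit; its input variables follow the "stuff" loader's defaults, summaries and question):

from langchain import OpenAI
from langchain.chains.qa_with_sources.loading import load_qa_with_sources_chain
from langchain.prompts import PromptTemplate

# Override the default "stuff" prompt; any keyword accepted by _load_stuff_chain
# (prompt, document_prompt, document_variable_name, ...) can be passed this way.
prompt = PromptTemplate(
    template="Answer from the sources.\n{summaries}\nQuestion: {question}\nAnswer:",
    input_variables=["summaries", "question"],
)
chain = load_qa_with_sources_chain(OpenAI(temperature=0), chain_type="stuff", prompt=prompt)

# An unrecognized chain_type raises the ValueError defined above.
try:
    load_qa_with_sources_chain(OpenAI(temperature=0), chain_type="reduce")
except ValueError as err:
    print(err)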

@@ -9,6 +9,7 @@ from langchain.chains.base import Chain
 from langchain.chains.combine_documents.base import BaseCombineDocumentsChain
 from langchain.chains.combine_documents.stuff import StuffDocumentsChain
 from langchain.chains.llm import LLMChain
+from langchain.chains.question_answering import load_qa_chain
 from langchain.chains.vector_db_qa.prompt import PROMPT
 from langchain.llms.base import BaseLLM
 from langchain.prompts import PromptTemplate
@@ -104,19 +105,24 @@ class VectorDBQA(Chain, BaseModel):

         return cls(combine_documents_chain=combine_documents_chain, **kwargs)

-    def _call(
-        self,
-        inputs: Dict[str, str],
-    ) -> Dict[str, Any]:
+    @classmethod
+    def from_chain_type(
+        cls, llm: BaseLLM, chain_type: str = "stuff", **kwargs: Any
+    ) -> VectorDBQA:
+        """Load chain from chain type."""
+        combine_documents_chain = load_qa_chain(llm, chain_type=chain_type)
+        return cls(combine_documents_chain=combine_documents_chain, **kwargs)
+
+    def _call(self, inputs: Dict[str, str]) -> Dict[str, Any]:
         """Run similarity search and llm on input query.

-        If inputs contains 'return_source_documents' as 'True', returns
+        If the chain has 'return_source_documents' as 'True', returns
         the retrieved documents as well under the key 'source_documents'.

         Example:
         .. code-block:: python

-            res = vectordbqa({'query': 'This is my query', 'return_source_documents': True})
+            res = vectordbqa({'query': 'This is my query'})
             answer, docs = res['result'], res['source_documents']
         """
         question = inputs[self.input_key]
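Note the behavioural change the docstring edit records: return_source_documents is now set when the chain is constructed rather than passed per call. A minimal usage sketch, assuming the same imports and docsearch vectorstore as in the notebooks above:

# The flag lives on the chain itself...
qa = VectorDBQA.from_chain_type(
    llm=OpenAI(), chain_type="stuff", vectorstore=docsearch, return_source_documents=True
)

# ...so the call carries only the query.
res = qa({"query": "What did the president say about Ketanji Brown Jackson"})
answer, docs = res["result"], res["source_documents"]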