diff --git a/docs/modules/chains/combine_docs_examples/vector_db_qa.ipynb b/docs/modules/chains/combine_docs_examples/vector_db_qa.ipynb index 74ef7d183f8..3f90eccdba8 100644 --- a/docs/modules/chains/combine_docs_examples/vector_db_qa.ipynb +++ b/docs/modules/chains/combine_docs_examples/vector_db_qa.ipynb @@ -46,7 +46,7 @@ "metadata": {}, "outputs": [], "source": [ - "qa = VectorDBQA.from_llm(llm=OpenAI(), vectorstore=docsearch)" + "qa = VectorDBQA.from_chain_type(llm=OpenAI(), chain_type=\"stuff\", vectorstore=docsearch)" ] }, { @@ -58,7 +58,7 @@ { "data": { "text/plain": [ - "' The president said that Ketanji Brown Jackson is one of the nation’s top legal minds and that she will continue Justice Breyer’s legacy of excellence.'" + "\" The president said that Ketanji Brown Jackson is one of the nation's top legal minds, a former top litigator in private practice, a former federal public defender, from a family of public school educators and police officers, a consensus builder, and has received a broad range of support from the Fraternal Order of Police to former judges appointed by Democrats and Republicans.\"" ] }, "execution_count": 4, @@ -71,6 +71,91 @@ "qa.run(query)" ] }, + { + "cell_type": "markdown", + "id": "c28f1f64", + "metadata": {}, + "source": [ + "## Chain Type\n", + "You can easily specify different chain types to load and use in the VectorDBQA chain. For a more detailed walkthrough of these types, please see [this notebook](question_answering.ipynb).\n", + "\n", + "There are two ways to load different chain types. First, you can specify the chain type argument in the `from_chain_type` method. This allows you to pass in the name of the chain type you want to use. For example, in the below we change the chain type to `map_reduce`." + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "id": "22d2417d", + "metadata": {}, + "outputs": [], + "source": [ + "qa = VectorDBQA.from_chain_type(llm=OpenAI(), chain_type=\"map_reduce\", vectorstore=docsearch)" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "id": "43204ad1", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "\" The president said that Ketanji Brown Jackson is one of the nation's top legal minds, a former top litigator in private practice, a former federal public defender, from a family of public school educators and police officers, a consensus builder, and has received a broad range of support from the Fraternal Order of Police to former judges appointed by Democrats and Republicans.\"" + ] + }, + "execution_count": 6, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "query = \"What did the president say about Ketanji Brown Jackson\"\n", + "qa.run(query)" + ] + }, + { + "cell_type": "markdown", + "id": "60368f38", + "metadata": {}, + "source": [ + "The above way allows you to really simply change the chain_type, but it does provide a ton of flexibility over parameters to that chain type. If you want to control those parameters, you can load the chain directly (as you did in [this notebook](question_answering.ipynb)) and then pass that directly to the the VectorDBQA chain with the `combine_documents_chain` parameter. For example:" + ] + }, + { + "cell_type": "code", + "execution_count": 18, + "id": "7b403f0d", + "metadata": {}, + "outputs": [], + "source": [ + "from langchain.chains.question_answering import load_qa_chain\n", + "qa_chain = load_qa_chain(OpenAI(temperature=0), chain_type=\"stuff\")\n", + "qa = VectorDBQA(combine_documents_chain=qa_chain, vectorstore=docsearch)" + ] + }, + { + "cell_type": "code", + "execution_count": 19, + "id": "9e04a9ac", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "\" The president said that Ketanji Brown Jackson is one of the nation's top legal minds, a former top litigator in private practice, a former federal public defender, and from a family of public school educators and police officers. He also said that she is a consensus builder and has received a broad range of support from the Fraternal Order of Police to former judges appointed by Democrats and Republicans.\"" + ] + }, + "execution_count": 19, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "query = \"What did the president say about Ketanji Brown Jackson\"\n", + "qa.run(query)" + ] + }, { "cell_type": "markdown", "id": "0b8c37f7", @@ -87,7 +172,7 @@ "metadata": {}, "outputs": [], "source": [ - "qa = VectorDBQA.from_llm(llm=OpenAI(), vectorstore=docsearch, return_source_documents=True)" + "qa = VectorDBQA.from_chain_type(llm=OpenAI(), chain_type=\"stuff\", return_source_documents=True)" ] }, { diff --git a/docs/modules/chains/combine_docs_examples/vector_db_qa_with_sources.ipynb b/docs/modules/chains/combine_docs_examples/vector_db_qa_with_sources.ipynb index e39b91eb431..c03fa18319c 100644 --- a/docs/modules/chains/combine_docs_examples/vector_db_qa_with_sources.ipynb +++ b/docs/modules/chains/combine_docs_examples/vector_db_qa_with_sources.ipynb @@ -26,7 +26,7 @@ }, { "cell_type": "code", - "execution_count": 3, + "execution_count": 2, "id": "17d1306e", "metadata": {}, "outputs": [], @@ -41,7 +41,7 @@ }, { "cell_type": "code", - "execution_count": 4, + "execution_count": 3, "id": "0e745d99", "metadata": {}, "outputs": [], @@ -51,7 +51,7 @@ }, { "cell_type": "code", - "execution_count": 5, + "execution_count": 4, "id": "f42d79dc", "metadata": {}, "outputs": [], @@ -63,7 +63,7 @@ }, { "cell_type": "code", - "execution_count": 6, + "execution_count": 5, "id": "8aa571ae", "metadata": {}, "outputs": [], @@ -73,26 +73,69 @@ }, { "cell_type": "code", - "execution_count": 8, + "execution_count": 6, "id": "aa859d4c", "metadata": {}, "outputs": [], "source": [ "from langchain import OpenAI\n", "\n", - "chain = VectorDBQAWithSourcesChain.from_llm(OpenAI(temperature=0), vectorstore=docsearch)" + "chain = VectorDBQAWithSourcesChain.from_chain_type(OpenAI(temperature=0), chain_type=\"stuff\", vectorstore=docsearch)" ] }, { "cell_type": "code", - "execution_count": 9, + "execution_count": 7, "id": "8ba36fa7", "metadata": {}, "outputs": [ { "data": { "text/plain": [ - "{'answer': ' The president thanked Justice Breyer for his service.',\n", + "{'answer': ' The president thanked Justice Breyer for his service.\\n',\n", + " 'sources': '30-pl'}" + ] + }, + "execution_count": 7, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "chain({\"question\": \"What did the president say about Justice Breyer\"}, return_only_outputs=True)" + ] + }, + { + "cell_type": "markdown", + "id": "718ecbda", + "metadata": {}, + "source": [ + "## Chain Type\n", + "You can easily specify different chain types to load and use in the VectorDBQAWithSourcesChain chain. For a more detailed walkthrough of these types, please see [this notebook](qa_with_sources.ipynb).\n", + "\n", + "There are two ways to load different chain types. First, you can specify the chain type argument in the `from_chain_type` method. This allows you to pass in the name of the chain type you want to use. For example, in the below we change the chain type to `map_reduce`." + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "id": "8b35b30a", + "metadata": {}, + "outputs": [], + "source": [ + "chain = VectorDBQAWithSourcesChain.from_chain_type(OpenAI(temperature=0), chain_type=\"map_reduce\", vectorstore=docsearch)" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "id": "58bd424f", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "{'answer': ' The president honored Justice Stephen Breyer for his service.\\n',\n", " 'sources': '30-pl'}" ] }, @@ -104,11 +147,53 @@ "source": [ "chain({\"question\": \"What did the president say about Justice Breyer\"}, return_only_outputs=True)" ] + }, + { + "cell_type": "markdown", + "id": "21e14eed", + "metadata": {}, + "source": [ + "The above way allows you to really simply change the chain_type, but it does provide a ton of flexibility over parameters to that chain type. If you want to control those parameters, you can load the chain directly (as you did in [this notebook](qa_with_sources.ipynb)) and then pass that directly to the the VectorDBQA chain with the `combine_documents_chain` parameter. For example:" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "id": "af35f0c6", + "metadata": {}, + "outputs": [], + "source": [ + "from langchain.chains.qa_with_sources import load_qa_with_sources_chain\n", + "qa_chain = load_qa_with_sources_chain(OpenAI(temperature=0), chain_type=\"stuff\")\n", + "qa = VectorDBQAWithSourcesChain(combine_document_chain=qa_chain, vectorstore=docsearch)" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "id": "c91fdc8a", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "{'answer': ' The president honored Justice Stephen Breyer for his service.\\n',\n", + " 'sources': '30-pl'}" + ] + }, + "execution_count": 11, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "chain({\"question\": \"What did the president say about Justice Breyer\"}, return_only_outputs=True)" + ] } ], "metadata": { "kernelspec": { - "display_name": "Python 3.9.0 64-bit ('llm-env')", + "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, @@ -122,7 +207,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.9.0" + "version": "3.10.9" }, "vscode": { "interpreter": { diff --git a/langchain/chains/qa_with_sources/__init__.py b/langchain/chains/qa_with_sources/__init__.py index 4af17329798..b1d18e832a3 100644 --- a/langchain/chains/qa_with_sources/__init__.py +++ b/langchain/chains/qa_with_sources/__init__.py @@ -1,169 +1,4 @@ """Load question answering with sources chains.""" -from typing import Any, Mapping, Optional, Protocol +from langchain.chains.qa_with_sources.loading import load_qa_with_sources_chain -from langchain.chains.combine_documents.base import BaseCombineDocumentsChain -from langchain.chains.combine_documents.map_reduce import MapReduceDocumentsChain -from langchain.chains.combine_documents.map_rerank import MapRerankDocumentsChain -from langchain.chains.combine_documents.refine import RefineDocumentsChain -from langchain.chains.combine_documents.stuff import StuffDocumentsChain -from langchain.chains.llm import LLMChain -from langchain.chains.qa_with_sources import ( - map_reduce_prompt, - refine_prompts, - stuff_prompt, -) -from langchain.chains.question_answering import map_rerank_prompt -from langchain.llms.base import BaseLLM -from langchain.prompts.base import BasePromptTemplate - - -class LoadingCallable(Protocol): - """Interface for loading the combine documents chain.""" - - def __call__(self, llm: BaseLLM, **kwargs: Any) -> BaseCombineDocumentsChain: - """Callable to load the combine documents chain.""" - - -def _load_map_rerank_chain( - llm: BaseLLM, - prompt: BasePromptTemplate = map_rerank_prompt.PROMPT, - verbose: bool = False, - document_variable_name: str = "context", - rank_key: str = "score", - answer_key: str = "answer", - **kwargs: Any, -) -> MapRerankDocumentsChain: - llm_chain = LLMChain(llm=llm, prompt=prompt, verbose=verbose) - return MapRerankDocumentsChain( - llm_chain=llm_chain, - rank_key=rank_key, - answer_key=answer_key, - document_variable_name=document_variable_name, - **kwargs, - ) - - -def _load_stuff_chain( - llm: BaseLLM, - prompt: BasePromptTemplate = stuff_prompt.PROMPT, - document_prompt: BasePromptTemplate = stuff_prompt.EXAMPLE_PROMPT, - document_variable_name: str = "summaries", - verbose: Optional[bool] = None, - **kwargs: Any, -) -> StuffDocumentsChain: - llm_chain = LLMChain(llm=llm, prompt=prompt, verbose=verbose) - return StuffDocumentsChain( - llm_chain=llm_chain, - document_variable_name=document_variable_name, - document_prompt=document_prompt, - verbose=verbose, - **kwargs, - ) - - -def _load_map_reduce_chain( - llm: BaseLLM, - question_prompt: BasePromptTemplate = map_reduce_prompt.QUESTION_PROMPT, - combine_prompt: BasePromptTemplate = map_reduce_prompt.COMBINE_PROMPT, - document_prompt: BasePromptTemplate = map_reduce_prompt.EXAMPLE_PROMPT, - combine_document_variable_name: str = "summaries", - map_reduce_document_variable_name: str = "context", - collapse_prompt: Optional[BasePromptTemplate] = None, - reduce_llm: Optional[BaseLLM] = None, - collapse_llm: Optional[BaseLLM] = None, - verbose: Optional[bool] = None, - **kwargs: Any, -) -> MapReduceDocumentsChain: - map_chain = LLMChain(llm=llm, prompt=question_prompt, verbose=verbose) - _reduce_llm = reduce_llm or llm - reduce_chain = LLMChain(llm=_reduce_llm, prompt=combine_prompt, verbose=verbose) - combine_document_chain = StuffDocumentsChain( - llm_chain=reduce_chain, - document_variable_name=combine_document_variable_name, - document_prompt=document_prompt, - verbose=verbose, - ) - if collapse_prompt is None: - collapse_chain = None - if collapse_llm is not None: - raise ValueError( - "collapse_llm provided, but collapse_prompt was not: please " - "provide one or stop providing collapse_llm." - ) - else: - _collapse_llm = collapse_llm or llm - collapse_chain = StuffDocumentsChain( - llm_chain=LLMChain( - llm=_collapse_llm, - prompt=collapse_prompt, - verbose=verbose, - ), - document_variable_name=combine_document_variable_name, - document_prompt=document_prompt, - ) - return MapReduceDocumentsChain( - llm_chain=map_chain, - combine_document_chain=combine_document_chain, - document_variable_name=map_reduce_document_variable_name, - collapse_document_chain=collapse_chain, - verbose=verbose, - **kwargs, - ) - - -def _load_refine_chain( - llm: BaseLLM, - question_prompt: BasePromptTemplate = refine_prompts.DEFAULT_TEXT_QA_PROMPT, - refine_prompt: BasePromptTemplate = refine_prompts.DEFAULT_REFINE_PROMPT, - document_prompt: BasePromptTemplate = refine_prompts.EXAMPLE_PROMPT, - document_variable_name: str = "context_str", - initial_response_name: str = "existing_answer", - refine_llm: Optional[BaseLLM] = None, - verbose: Optional[bool] = None, - **kwargs: Any, -) -> RefineDocumentsChain: - initial_chain = LLMChain(llm=llm, prompt=question_prompt, verbose=verbose) - _refine_llm = refine_llm or llm - refine_chain = LLMChain(llm=_refine_llm, prompt=refine_prompt, verbose=verbose) - return RefineDocumentsChain( - initial_llm_chain=initial_chain, - refine_llm_chain=refine_chain, - document_variable_name=document_variable_name, - initial_response_name=initial_response_name, - document_prompt=document_prompt, - verbose=verbose, - **kwargs, - ) - - -def load_qa_with_sources_chain( - llm: BaseLLM, - chain_type: str = "stuff", - verbose: Optional[bool] = None, - **kwargs: Any, -) -> BaseCombineDocumentsChain: - """Load question answering with sources chain. - - Args: - llm: Language Model to use in the chain. - chain_type: Type of document combining chain to use. Should be one of "stuff", - "map_reduce", and "refine". - verbose: Whether chains should be run in verbose mode or not. Note that this - applies to all chains that make up the final chain. - - Returns: - A chain to use for question answering with sources. - """ - loader_mapping: Mapping[str, LoadingCallable] = { - "stuff": _load_stuff_chain, - "map_reduce": _load_map_reduce_chain, - "refine": _load_refine_chain, - "map_rerank": _load_map_rerank_chain, - } - if chain_type not in loader_mapping: - raise ValueError( - f"Got unsupported chain type: {chain_type}. " - f"Should be one of {loader_mapping.keys()}" - ) - _func: LoadingCallable = loader_mapping[chain_type] - return _func(llm, verbose=verbose, **kwargs) +__all__ = ["load_qa_with_sources_chain"] diff --git a/langchain/chains/qa_with_sources/base.py b/langchain/chains/qa_with_sources/base.py index 96b02d07201..628db71bc2e 100644 --- a/langchain/chains/qa_with_sources/base.py +++ b/langchain/chains/qa_with_sources/base.py @@ -12,6 +12,7 @@ from langchain.chains.combine_documents.base import BaseCombineDocumentsChain from langchain.chains.combine_documents.map_reduce import MapReduceDocumentsChain from langchain.chains.combine_documents.stuff import StuffDocumentsChain from langchain.chains.llm import LLMChain +from langchain.chains.qa_with_sources.loading import load_qa_with_sources_chain from langchain.chains.qa_with_sources.map_reduce_prompt import ( COMBINE_PROMPT, EXAMPLE_PROMPT, @@ -25,7 +26,7 @@ from langchain.prompts.base import BasePromptTemplate class BaseQAWithSourcesChain(Chain, BaseModel, ABC): """Question answering with sources over documents.""" - combine_document_chain: BaseCombineDocumentsChain + combine_documents_chain: BaseCombineDocumentsChain """Chain to use to combine documents.""" question_key: str = "question" #: :meta private: input_docs_key: str = "docs" #: :meta private: @@ -55,10 +56,18 @@ class BaseQAWithSourcesChain(Chain, BaseModel, ABC): document_variable_name="context", ) return cls( - combine_document_chain=combine_document_chain, + combine_documents_chain=combine_document_chain, **kwargs, ) + @classmethod + def from_chain_type( + cls, llm: BaseLLM, chain_type: str = "stuff", **kwargs: Any + ) -> BaseQAWithSourcesChain: + """Load chain from chain type.""" + combine_document_chain = load_qa_with_sources_chain(llm, chain_type=chain_type) + return cls(combine_documents_chain=combine_document_chain, **kwargs) + class Config: """Configuration for this pydantic object.""" @@ -81,11 +90,11 @@ class BaseQAWithSourcesChain(Chain, BaseModel, ABC): """ return [self.answer_key, self.sources_answer_key] - @root_validator() - def validate_combine_chain_can_be_constructed(cls, values: Dict) -> Dict: - """Validate that the combine chain can be constructed.""" - # Try to construct the combine documents chains. - + @root_validator(pre=True) + def validate_naming(cls, values: Dict) -> Dict: + """Fix backwards compatability in naming.""" + if "combine_document_chain" in values: + values["combine_documents_chain"] = values.pop("combine_document_chain") return values @abstractmethod @@ -94,7 +103,7 @@ class BaseQAWithSourcesChain(Chain, BaseModel, ABC): def _call(self, inputs: Dict[str, Any]) -> Dict[str, str]: docs = self._get_docs(inputs) - answer, _ = self.combine_document_chain.combine_docs(docs, **inputs) + answer, _ = self.combine_documents_chain.combine_docs(docs, **inputs) if "SOURCES: " in answer: answer, sources = answer.split("SOURCES: ") else: diff --git a/langchain/chains/qa_with_sources/loading.py b/langchain/chains/qa_with_sources/loading.py new file mode 100644 index 00000000000..4af17329798 --- /dev/null +++ b/langchain/chains/qa_with_sources/loading.py @@ -0,0 +1,169 @@ +"""Load question answering with sources chains.""" +from typing import Any, Mapping, Optional, Protocol + +from langchain.chains.combine_documents.base import BaseCombineDocumentsChain +from langchain.chains.combine_documents.map_reduce import MapReduceDocumentsChain +from langchain.chains.combine_documents.map_rerank import MapRerankDocumentsChain +from langchain.chains.combine_documents.refine import RefineDocumentsChain +from langchain.chains.combine_documents.stuff import StuffDocumentsChain +from langchain.chains.llm import LLMChain +from langchain.chains.qa_with_sources import ( + map_reduce_prompt, + refine_prompts, + stuff_prompt, +) +from langchain.chains.question_answering import map_rerank_prompt +from langchain.llms.base import BaseLLM +from langchain.prompts.base import BasePromptTemplate + + +class LoadingCallable(Protocol): + """Interface for loading the combine documents chain.""" + + def __call__(self, llm: BaseLLM, **kwargs: Any) -> BaseCombineDocumentsChain: + """Callable to load the combine documents chain.""" + + +def _load_map_rerank_chain( + llm: BaseLLM, + prompt: BasePromptTemplate = map_rerank_prompt.PROMPT, + verbose: bool = False, + document_variable_name: str = "context", + rank_key: str = "score", + answer_key: str = "answer", + **kwargs: Any, +) -> MapRerankDocumentsChain: + llm_chain = LLMChain(llm=llm, prompt=prompt, verbose=verbose) + return MapRerankDocumentsChain( + llm_chain=llm_chain, + rank_key=rank_key, + answer_key=answer_key, + document_variable_name=document_variable_name, + **kwargs, + ) + + +def _load_stuff_chain( + llm: BaseLLM, + prompt: BasePromptTemplate = stuff_prompt.PROMPT, + document_prompt: BasePromptTemplate = stuff_prompt.EXAMPLE_PROMPT, + document_variable_name: str = "summaries", + verbose: Optional[bool] = None, + **kwargs: Any, +) -> StuffDocumentsChain: + llm_chain = LLMChain(llm=llm, prompt=prompt, verbose=verbose) + return StuffDocumentsChain( + llm_chain=llm_chain, + document_variable_name=document_variable_name, + document_prompt=document_prompt, + verbose=verbose, + **kwargs, + ) + + +def _load_map_reduce_chain( + llm: BaseLLM, + question_prompt: BasePromptTemplate = map_reduce_prompt.QUESTION_PROMPT, + combine_prompt: BasePromptTemplate = map_reduce_prompt.COMBINE_PROMPT, + document_prompt: BasePromptTemplate = map_reduce_prompt.EXAMPLE_PROMPT, + combine_document_variable_name: str = "summaries", + map_reduce_document_variable_name: str = "context", + collapse_prompt: Optional[BasePromptTemplate] = None, + reduce_llm: Optional[BaseLLM] = None, + collapse_llm: Optional[BaseLLM] = None, + verbose: Optional[bool] = None, + **kwargs: Any, +) -> MapReduceDocumentsChain: + map_chain = LLMChain(llm=llm, prompt=question_prompt, verbose=verbose) + _reduce_llm = reduce_llm or llm + reduce_chain = LLMChain(llm=_reduce_llm, prompt=combine_prompt, verbose=verbose) + combine_document_chain = StuffDocumentsChain( + llm_chain=reduce_chain, + document_variable_name=combine_document_variable_name, + document_prompt=document_prompt, + verbose=verbose, + ) + if collapse_prompt is None: + collapse_chain = None + if collapse_llm is not None: + raise ValueError( + "collapse_llm provided, but collapse_prompt was not: please " + "provide one or stop providing collapse_llm." + ) + else: + _collapse_llm = collapse_llm or llm + collapse_chain = StuffDocumentsChain( + llm_chain=LLMChain( + llm=_collapse_llm, + prompt=collapse_prompt, + verbose=verbose, + ), + document_variable_name=combine_document_variable_name, + document_prompt=document_prompt, + ) + return MapReduceDocumentsChain( + llm_chain=map_chain, + combine_document_chain=combine_document_chain, + document_variable_name=map_reduce_document_variable_name, + collapse_document_chain=collapse_chain, + verbose=verbose, + **kwargs, + ) + + +def _load_refine_chain( + llm: BaseLLM, + question_prompt: BasePromptTemplate = refine_prompts.DEFAULT_TEXT_QA_PROMPT, + refine_prompt: BasePromptTemplate = refine_prompts.DEFAULT_REFINE_PROMPT, + document_prompt: BasePromptTemplate = refine_prompts.EXAMPLE_PROMPT, + document_variable_name: str = "context_str", + initial_response_name: str = "existing_answer", + refine_llm: Optional[BaseLLM] = None, + verbose: Optional[bool] = None, + **kwargs: Any, +) -> RefineDocumentsChain: + initial_chain = LLMChain(llm=llm, prompt=question_prompt, verbose=verbose) + _refine_llm = refine_llm or llm + refine_chain = LLMChain(llm=_refine_llm, prompt=refine_prompt, verbose=verbose) + return RefineDocumentsChain( + initial_llm_chain=initial_chain, + refine_llm_chain=refine_chain, + document_variable_name=document_variable_name, + initial_response_name=initial_response_name, + document_prompt=document_prompt, + verbose=verbose, + **kwargs, + ) + + +def load_qa_with_sources_chain( + llm: BaseLLM, + chain_type: str = "stuff", + verbose: Optional[bool] = None, + **kwargs: Any, +) -> BaseCombineDocumentsChain: + """Load question answering with sources chain. + + Args: + llm: Language Model to use in the chain. + chain_type: Type of document combining chain to use. Should be one of "stuff", + "map_reduce", and "refine". + verbose: Whether chains should be run in verbose mode or not. Note that this + applies to all chains that make up the final chain. + + Returns: + A chain to use for question answering with sources. + """ + loader_mapping: Mapping[str, LoadingCallable] = { + "stuff": _load_stuff_chain, + "map_reduce": _load_map_reduce_chain, + "refine": _load_refine_chain, + "map_rerank": _load_map_rerank_chain, + } + if chain_type not in loader_mapping: + raise ValueError( + f"Got unsupported chain type: {chain_type}. " + f"Should be one of {loader_mapping.keys()}" + ) + _func: LoadingCallable = loader_mapping[chain_type] + return _func(llm, verbose=verbose, **kwargs) diff --git a/langchain/chains/vector_db_qa/base.py b/langchain/chains/vector_db_qa/base.py index c8401a73d3f..57ae63744f7 100644 --- a/langchain/chains/vector_db_qa/base.py +++ b/langchain/chains/vector_db_qa/base.py @@ -9,6 +9,7 @@ from langchain.chains.base import Chain from langchain.chains.combine_documents.base import BaseCombineDocumentsChain from langchain.chains.combine_documents.stuff import StuffDocumentsChain from langchain.chains.llm import LLMChain +from langchain.chains.question_answering import load_qa_chain from langchain.chains.vector_db_qa.prompt import PROMPT from langchain.llms.base import BaseLLM from langchain.prompts import PromptTemplate @@ -104,19 +105,24 @@ class VectorDBQA(Chain, BaseModel): return cls(combine_documents_chain=combine_documents_chain, **kwargs) - def _call( - self, - inputs: Dict[str, str], - ) -> Dict[str, Any]: + @classmethod + def from_chain_type( + cls, llm: BaseLLM, chain_type: str = "stuff", **kwargs: Any + ) -> VectorDBQA: + """Load chain from chain type.""" + combine_documents_chain = load_qa_chain(llm, chain_type=chain_type) + return cls(combine_documents_chain=combine_documents_chain, **kwargs) + + def _call(self, inputs: Dict[str, str]) -> Dict[str, Any]: """Run similarity search and llm on input query. - If inputs contains 'return_source_documents' as 'True', returns + If chain has 'return_source_documents' as 'True', returns the retrieved documents as well under the key 'source_documents'. Example: .. code-block:: python - res = vectordbqa({'query': 'This is my query', 'return_source_documents': True}) + res = vectordbqa({'query': 'This is my query'}) answer, docs = res['result'], res['source_documents'] """ question = inputs[self.input_key]