Compare commits

...

1 Commit

Author    SHA1        Message  Date
Dev 2049  4f5eee1d45  refac    2023-05-16 12:31:51 -07:00
6 changed files with 234 additions and 104 deletions

View File

@@ -5,6 +5,7 @@ from typing import Any, Dict, List, Optional, Tuple
from pydantic import Field
from langchain import PromptTemplate
from langchain.callbacks.manager import (
AsyncCallbackManagerForChainRun,
CallbackManagerForChainRun,
@@ -15,10 +16,13 @@ from langchain.prompts.base import BasePromptTemplate
from langchain.text_splitter import RecursiveCharacterTextSplitter, TextSplitter
def get_default_document_prompt() -> PromptTemplate:
return PromptTemplate(input_variables=["page_content"], template="{page_content}")
def format_document(doc: Document, prompt: BasePromptTemplate) -> str:
"""Format a document into a string based on a prompt template."""
base_info = {"page_content": doc.page_content}
base_info.update(doc.metadata)
base_info = {"page_content": doc.page_content, **doc.metadata}
missing_metadata = set(prompt.input_variables).difference(base_info)
if len(missing_metadata) > 0:
required_metadata = [
@@ -36,17 +40,9 @@ def format_document(doc: Document, prompt: BasePromptTemplate) -> str:
class BaseCombineDocumentsChain(Chain, ABC):
"""Base interface for chains combining documents."""
input_key: str = "input_documents" #: :meta private:
input_documents_key: str = "input_documents" #: :meta private:
output_key: str = "output_text" #: :meta private:
@property
def input_keys(self) -> List[str]:
"""Expect input key.
:meta private:
"""
return [self.input_key]
@property
def output_keys(self) -> List[str]:
"""Return output key.
@@ -78,9 +74,9 @@ class BaseCombineDocumentsChain(Chain, ABC):
run_manager: Optional[CallbackManagerForChainRun] = None,
) -> Dict[str, str]:
_run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager()
docs = inputs[self.input_key]
docs = inputs[self.input_documents_key]
# Other keys are assumed to be needed for LLM prediction
other_keys = {k: v for k, v in inputs.items() if k != self.input_key}
other_keys = {k: v for k, v in inputs.items() if k != self.input_documents_key}
output, extra_return_dict = self.combine_docs(
docs, callbacks=_run_manager.get_child(), **other_keys
)
@@ -93,9 +89,9 @@ class BaseCombineDocumentsChain(Chain, ABC):
run_manager: Optional[AsyncCallbackManagerForChainRun] = None,
) -> Dict[str, str]:
_run_manager = run_manager or AsyncCallbackManagerForChainRun.get_noop_manager()
docs = inputs[self.input_key]
docs = inputs[self.input_documents_key]
# Other keys are assumed to be needed for LLM prediction
other_keys = {k: v for k, v in inputs.items() if k != self.input_key}
other_keys = {k: v for k, v in inputs.items() if k != self.input_documents_key}
output, extra_return_dict = await self.acombine_docs(
docs, callbacks=_run_manager.get_child(), **other_keys
)
@@ -136,7 +132,7 @@ class AnalyzeDocumentChain(Chain):
docs = self.text_splitter.create_documents([document])
# Other keys are assumed to be needed for LLM prediction
other_keys: Dict = {k: v for k, v in inputs.items() if k != self.input_key}
other_keys[self.combine_docs_chain.input_key] = docs
other_keys[self.combine_docs_chain.input_documents_key] = docs
return self.combine_docs_chain(
other_keys, return_only_outputs=True, callbacks=_run_manager.get_child()
)

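For reference, a minimal sketch of the reworked format_document (the metadata-aware prompt below is a hypothetical example, not part of this diff):

    from langchain import PromptTemplate
    from langchain.chains.combine_documents.base import format_document
    from langchain.docstore.document import Document

    doc = Document(page_content="LangChain docs", metadata={"source": "docs.md"})
    prompt = PromptTemplate(
        input_variables=["page_content", "source"],
        template="{page_content} (from {source})",
    )
    # page_content and metadata are merged into a single dict; a prompt
    # variable missing from both raises a ValueError.
    print(format_document(doc, prompt))  # -> LangChain docs (from docs.md)
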
View File

@@ -4,10 +4,16 @@ from __future__ import annotations
from typing import Any, Callable, Dict, List, Optional, Protocol, Tuple
from pydantic import Extra, root_validator
from mypy_extensions import KwArg
from pydantic import Extra, Field, root_validator
from langchain import BasePromptTemplate
from langchain.callbacks.manager import Callbacks
from langchain.chains.combine_documents.base import BaseCombineDocumentsChain
from langchain.chains.combine_documents.base import (
BaseCombineDocumentsChain,
format_document,
get_default_document_prompt,
)
from langchain.chains.llm import LLMChain
from langchain.docstore.document import Document
@@ -46,10 +52,8 @@ def _split_list_of_docs(
def _collapse_docs(
docs: List[Document],
combine_document_func: CombineDocsProtocol,
**kwargs: Any,
combine_docs_result: str,
) -> Document:
result = combine_document_func(docs, **kwargs)
combined_metadata = {k: str(v) for k, v in docs[0].metadata.items()}
for doc in docs[1:]:
for k, v in doc.metadata.items():
@@ -57,7 +61,7 @@ def _collapse_docs(
combined_metadata[k] += f", {v}"
else:
combined_metadata[k] = str(v)
return Document(page_content=result, metadata=combined_metadata)
return Document(page_content=combine_docs_result, metadata=combined_metadata)
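
With this change _collapse_docs no longer runs the combine function itself; the caller passes in the already-combined text. A sketch of the new call shape (values illustrative; _collapse_docs is module-private):

    from langchain.chains.combine_documents.map_reduce import _collapse_docs
    from langchain.docstore.document import Document

    docs = [
        Document(page_content="part one", metadata={"source": "a.md"}),
        Document(page_content="part two", metadata={"source": "b.md"}),
    ]
    collapsed = _collapse_docs(docs, "combined summary")
    # metadata values are stringified and joined with ", "
    assert collapsed.metadata == {"source": "a.md, b.md"}
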
class MapReduceDocumentsChain(BaseCombineDocumentsChain):
@@ -65,6 +69,10 @@ class MapReduceDocumentsChain(BaseCombineDocumentsChain):
llm_chain: LLMChain
"""Chain to apply to each document individually."""
document_prompt: BasePromptTemplate = Field(
default_factory=get_default_document_prompt
)
"""Prompt to use to format each document."""
combine_document_chain: BaseCombineDocumentsChain
"""Chain to use to combine results of applying llm_chain to documents."""
collapse_document_chain: Optional[BaseCombineDocumentsChain] = None
@@ -76,9 +84,24 @@ class MapReduceDocumentsChain(BaseCombineDocumentsChain):
return_intermediate_steps: bool = False
"""Return the results of the map steps in the output."""
@property
def input_keys(self) -> List[str]:
"""Return input keys."""
all_keys = set(
[self.input_documents_key, "token_max"]
+ self.llm_chain.input_keys
+ self._collapse_chain.input_keys
)
internal_keys = [
self.document_variable_name,
self.combine_document_chain.input_documents_key,
self._collapse_chain.input_documents_key,
]
return list(set(all_keys).difference(internal_keys))
@property
def output_keys(self) -> List[str]:
"""Expect input key.
"""Return output keys.
:meta private:
"""
@@ -129,6 +152,13 @@ class MapReduceDocumentsChain(BaseCombineDocumentsChain):
else:
return self.combine_document_chain
def _get_llm_chain_inputs(self, docs: List[Document], **kwargs: Any) -> List[dict]:
# Format each document according to the prompt
doc_strings = [format_document(doc, self.document_prompt) for doc in docs]
# Join the documents together to put them in the prompt.
_kwargs = {k: v for k, v in kwargs.items() if k in self.llm_chain.input_keys}
return [{self.document_variable_name: _doc, **_kwargs} for _doc in doc_strings]
def combine_docs(
self,
docs: List[Document],
@@ -141,35 +171,41 @@ class MapReduceDocumentsChain(BaseCombineDocumentsChain):
Combine by mapping first chain over all documents, then reducing the results.
This reducing can be done recursively if needed (if there are many documents).
"""
results = self.llm_chain.apply(
# FYI - this is parallelized and so it is fast.
[{self.document_variable_name: d.page_content, **kwargs} for d in docs],
callbacks=callbacks,
)
inputs = self._get_llm_chain_inputs(docs, **kwargs)
# FYI - this is parallelized and so it is fast.
results = self.llm_chain.apply(inputs, callbacks=callbacks)
return self._process_results(
results, docs, token_max, callbacks=callbacks, **kwargs
)
async def acombine_docs(
self, docs: List[Document], callbacks: Callbacks = None, **kwargs: Any
self,
docs: List[Document],
token_max: int = 3000,
callbacks: Callbacks = None,
**kwargs: Any,
) -> Tuple[str, dict]:
"""Combine documents in a map reduce manner.
Combine by mapping first chain over all documents, then reducing the results.
This reducing can be done recursively if needed (if there are many documents).
"""
results = await self.llm_chain.aapply(
# FYI - this is parallelized and so it is fast.
[{**{self.document_variable_name: d.page_content}, **kwargs} for d in docs],
callbacks=callbacks,
)
return self._process_results(results, docs, callbacks=callbacks, **kwargs)
inputs = self._get_llm_chain_inputs(docs, **kwargs)
# FYI - this is parallelized and so it is fast.
results = await self.llm_chain.aapply(inputs, callbacks=callbacks)
return self._process_results(
results, docs, token_max, callbacks=callbacks, **kwargs
)
@property
def _length_func(self) -> Callable[[List[Document], KwArg(Any)], Optional[int]]:
return self.combine_document_chain.prompt_length # type: ignore
def _process_results(
self,
results: List[Dict],
docs: List[Document],
token_max: int = 3000,
token_max: int,
callbacks: Callbacks = None,
**kwargs: Any,
) -> Tuple[str, dict]:
@@ -179,33 +215,28 @@ class MapReduceDocumentsChain(BaseCombineDocumentsChain):
# This uses metadata from the docs, and the textual results from `results`
for i, r in enumerate(results)
]
length_func = self.combine_document_chain.prompt_length
num_tokens = length_func(result_docs, **kwargs)
def _collapse_docs_func(docs: List[Document], **kwargs: Any) -> str:
return self._collapse_chain.run(
input_documents=docs, callbacks=callbacks, **kwargs
)
num_tokens = self._length_func(result_docs, **kwargs)
while num_tokens is not None and num_tokens > token_max:
new_result_doc_list = _split_list_of_docs(
result_docs, length_func, token_max, **kwargs
result_docs, self._length_func, token_max, **kwargs
)
result_docs = []
for docs in new_result_doc_list:
new_doc = _collapse_docs(docs, _collapse_docs_func, **kwargs)
result_docs.append(new_doc)
num_tokens = self.combine_document_chain.prompt_length(
result_docs, **kwargs
)
inputs = {self._collapse_chain.input_documents_key: docs, **kwargs}
result = self._collapse_chain.run(callbacks=callbacks, **inputs)
result_docs.append(_collapse_docs(docs, result))
num_tokens = self._length_func(result_docs, **kwargs)
if self.return_intermediate_steps:
_results = [r[self.llm_chain.output_key] for r in results]
extra_return_dict = {"intermediate_steps": _results}
else:
extra_return_dict = {}
output = self.combine_document_chain.run(
input_documents=result_docs, callbacks=callbacks, **kwargs
)
inputs = {
self.combine_document_chain.input_documents_key: result_docs,
**kwargs,
}
output = self.combine_document_chain.run(callbacks=callbacks, **inputs)
return output, extra_return_dict
@property

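Taken together, a minimal usage sketch of the refactored MapReduceDocumentsChain (assuming OpenAI credentials are configured and docs is a list of Documents; prompt wording is illustrative):

    from langchain import PromptTemplate
    from langchain.chains.combine_documents.map_reduce import MapReduceDocumentsChain
    from langchain.chains.combine_documents.stuff import StuffDocumentsChain
    from langchain.chains.llm import LLMChain
    from langchain.llms import OpenAI

    llm = OpenAI(temperature=0)
    map_prompt = PromptTemplate.from_template("Summarize this content:\n\n{context}")
    reduce_prompt = PromptTemplate.from_template("Combine these summaries:\n\n{context}")
    chain = MapReduceDocumentsChain(
        llm_chain=LLMChain(llm=llm, prompt=map_prompt),
        combine_document_chain=StuffDocumentsChain(
            llm_chain=LLMChain(llm=llm, prompt=reduce_prompt),
            document_variable_name="context",
        ),
        document_variable_name="context",
    )
    # token_max is now a public input key rather than a buried default
    summary = chain.run(input_documents=docs, token_max=3000)
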
View File

@@ -4,10 +4,15 @@ from __future__ import annotations
from typing import Any, Dict, List, Optional, Sequence, Tuple, Union, cast
from pydantic import Extra, root_validator
from pydantic import Extra, Field, root_validator
from langchain import BasePromptTemplate
from langchain.callbacks.manager import Callbacks
from langchain.chains.combine_documents.base import BaseCombineDocumentsChain
from langchain.chains.combine_documents.base import (
BaseCombineDocumentsChain,
format_document,
get_default_document_prompt,
)
from langchain.chains.llm import LLMChain
from langchain.docstore.document import Document
from langchain.output_parsers.regex import RegexParser
@@ -18,6 +23,10 @@ class MapRerankDocumentsChain(BaseCombineDocumentsChain):
llm_chain: LLMChain
"""Chain to apply to each document individually."""
document_prompt: BasePromptTemplate = Field(
default_factory=get_default_document_prompt
)
"""Prompt to use to format each document."""
document_variable_name: str
"""The variable name in the llm_chain to put the documents in.
If only one variable in the llm_chain, this need not be provided."""
@@ -34,9 +43,16 @@ class MapRerankDocumentsChain(BaseCombineDocumentsChain):
extra = Extra.forbid
arbitrary_types_allowed = True
@property
def input_keys(self) -> List[str]:
"""Return output keys."""
all_keys = set([self.input_documents_key] + self.llm_chain.input_keys)
internal_keys = [self.document_variable_name]
return list(all_keys.difference(internal_keys))
@property
def output_keys(self) -> List[str]:
"""Expect input key.
"""Return output keys.
:meta private:
"""
@@ -90,6 +106,13 @@ class MapRerankDocumentsChain(BaseCombineDocumentsChain):
)
return values
def _get_llm_chain_inputs(self, docs: List[Document], **kwargs: Any) -> List[dict]:
# Format each document according to the prompt
doc_strings = [format_document(doc, self.document_prompt) for doc in docs]
# Join the documents together to put them in the prompt.
_kwargs = {k: v for k, v in kwargs.items() if k in self.llm_chain.input_keys}
return [{self.document_variable_name: _doc, **_kwargs} for _doc in doc_strings]
def combine_docs(
self, docs: List[Document], callbacks: Callbacks = None, **kwargs: Any
) -> Tuple[str, dict]:
@@ -97,11 +120,9 @@ class MapRerankDocumentsChain(BaseCombineDocumentsChain):
Combine by mapping first chain over all documents, then reranking the results.
"""
results = self.llm_chain.apply_and_parse(
# FYI - this is parallelized and so it is fast.
[{**{self.document_variable_name: d.page_content}, **kwargs} for d in docs],
callbacks=callbacks,
)
inputs = self._get_llm_chain_inputs(docs, **kwargs)
# FYI - this is parallelized and so it is fast.
results = self.llm_chain.apply_and_parse(inputs, callbacks=callbacks)
return self._process_results(docs, results)
async def acombine_docs(
@@ -111,11 +132,9 @@ class MapRerankDocumentsChain(BaseCombineDocumentsChain):
Combine by mapping first chain over all documents, then reranking the results.
"""
results = await self.llm_chain.aapply_and_parse(
# FYI - this is parallelized and so it is fast.
[{**{self.document_variable_name: d.page_content}, **kwargs} for d in docs],
callbacks=callbacks,
)
inputs = self._get_llm_chain_inputs(docs, **kwargs)
# FYI - this is parallelized and so it is fast.
results = await self.llm_chain.aapply_and_parse(inputs, callbacks=callbacks)
return self._process_results(docs, results)
def _process_results(

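A sketch of how MapRerankDocumentsChain is wired up on this branch (the prompt and score format are illustrative; rank_key and answer_key are existing fields):

    from langchain import PromptTemplate
    from langchain.chains.combine_documents.map_rerank import MapRerankDocumentsChain
    from langchain.chains.llm import LLMChain
    from langchain.llms import OpenAI
    from langchain.output_parsers.regex import RegexParser

    parser = RegexParser(
        regex=r"Answer: (.*)\nScore: (\d+)", output_keys=["answer", "score"]
    )
    prompt = PromptTemplate(
        input_variables=["context", "question"],
        template=(
            "Use the context to answer, then give a 0-100 confidence score.\n"
            "Context: {context}\nQuestion: {question}\nAnswer:"
        ),
        output_parser=parser,
    )
    chain = MapRerankDocumentsChain(
        llm_chain=LLMChain(llm=OpenAI(temperature=0), prompt=prompt),
        document_variable_name="context",
        rank_key="score",
        answer_key="answer",
    )
    answer = chain.run(input_documents=docs, question="What changed?")
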
View File

@@ -6,10 +6,12 @@ from typing import Any, Dict, List, Tuple
from pydantic import Extra, Field, root_validator
from langchain.base_language import BaseLanguageModel
from langchain.callbacks.manager import Callbacks
from langchain.chains.combine_documents.base import (
BaseCombineDocumentsChain,
format_document,
get_default_document_prompt,
)
from langchain.chains.llm import LLMChain
from langchain.docstore.document import Document
@@ -17,10 +19,6 @@ from langchain.prompts.base import BasePromptTemplate
from langchain.prompts.prompt import PromptTemplate
def _get_default_document_prompt() -> PromptTemplate:
return PromptTemplate(input_variables=["page_content"], template="{page_content}")
class RefineDocumentsChain(BaseCombineDocumentsChain):
"""Combine documents by doing a first pass and then refining on more documents."""
@@ -30,19 +28,60 @@ class RefineDocumentsChain(BaseCombineDocumentsChain):
"""LLM chain to use when refining."""
document_variable_name: str
"""The variable name in the initial_llm_chain to put the documents in.
If only one variable in the initial_llm_chain, this need not be provided."""
If only one variable in the initial_llm_chain, this doesn't need to be specified.
"""
initial_response_name: str
"""The variable name to format the initial response in when refining."""
"""The variable name to format the initial response in when refining.
If only two variables are in the refine_llm_chain, this doesn't need to be
specified.
"""
document_prompt: BasePromptTemplate = Field(
default_factory=_get_default_document_prompt
default_factory=get_default_document_prompt
)
"""Prompt to use to format each document."""
return_intermediate_steps: bool = False
"""Return the results of the refine steps in the output."""
@classmethod
def from_llm(
cls,
llm: BaseLanguageModel,
question_prompt: BasePromptTemplate,
refine_prompt: BasePromptTemplate,
**kwargs: Any,
) -> RefineDocumentsChain:
"""Initialize RefineDocumentsChain from an LLM and two prompts.
Example:
from langchain.chains.combine_documents import RefineDocumentsChain
from langchain.chains.summarize.refine_prompts import PROMPT, REFINE_PROMPT
from langchain.llms import OpenAI
refine_docs_chain = RefineDocumentsChain.from_llm(
OpenAI(), PROMPT, REFINE_PROMPT
)
"""
initial_chain = LLMChain(llm=llm, prompt=question_prompt)
refine_chain = LLMChain(llm=llm, prompt=refine_prompt)
return RefineDocumentsChain(
initial_llm_chain=initial_chain,
refine_llm_chain=refine_chain,
**kwargs,
)
@property
def input_keys(self) -> List[str]:
"""Return input keys."""
keys = set(
[self.input_documents_key]
+ self.initial_llm_chain.input_keys
+ self.refine_llm_chain.input_keys
).difference([self.document_variable_name, self.initial_response_name])
return list(keys)
@property
def output_keys(self) -> List[str]:
"""Expect input key.
"""Return output keys.
:meta private:
"""
@@ -66,23 +105,36 @@ class RefineDocumentsChain(BaseCombineDocumentsChain):
return values
@root_validator(pre=True)
def get_default_document_variable_name(cls, values: Dict) -> Dict:
"""Get default document variable name, if not provided."""
def get_default_variable_names(cls, values: Dict) -> Dict:
"""Infer and validate sub-chain input variable names, if not provided."""
initial_inputs = values["initial_llm_chain"].input_keys
if "document_variable_name" not in values:
llm_chain_variables = values["initial_llm_chain"].prompt.input_variables
if len(llm_chain_variables) == 1:
values["document_variable_name"] = llm_chain_variables[0]
if len(initial_inputs) == 1:
values["document_variable_name"] = initial_inputs[0]
else:
raise ValueError(
"document_variable_name must be provided if there are "
"multiple llm_chain input_variables"
"multiple initial_llm_chain input_variables"
)
else:
llm_chain_variables = values["initial_llm_chain"].prompt.input_variables
if values["document_variable_name"] not in llm_chain_variables:
if values["document_variable_name"] not in initial_inputs:
raise ValueError(
f"document_variable_name {values['document_variable_name']} was "
f"not found in llm_chain input_variables: {llm_chain_variables}"
f"not found in initial_llm_chain input_keys: {initial_inputs}"
)
refine_inputs = values["refine_llm_chain"].input_keys
if "initial_response_name" not in values:
doc_input = values["document_variable_name"]
if len(refine_inputs) == 2:
init_resp_input = [i for i in refine_inputs if i != doc_input][0]
values["initial_response_name"] = init_resp_input
else:
raise ValueError(
"initial_response_name must be provided if there are "
"more than two refine_llm_chain input_variables"
)
else:
if values["initial_response_name"] not in refine_inputs:
raise ValueError(
f"initial_response_name {values['initial_response_name']} was not "
f"found in refine_llm_chain input_keys: {refine_inputs}"
)
return values
@@ -94,8 +146,7 @@ class RefineDocumentsChain(BaseCombineDocumentsChain):
res = self.initial_llm_chain.predict(callbacks=callbacks, **inputs)
refine_steps = [res]
for doc in docs[1:]:
base_inputs = self._construct_refine_inputs(doc, res)
inputs = {**base_inputs, **kwargs}
inputs = self._construct_refine_inputs(doc, res)
res = self.refine_llm_chain.predict(callbacks=callbacks, **inputs)
refine_steps.append(res)
return self._construct_result(refine_steps, res)
@@ -108,8 +159,7 @@ class RefineDocumentsChain(BaseCombineDocumentsChain):
res = await self.initial_llm_chain.apredict(callbacks=callbacks, **inputs)
refine_steps = [res]
for doc in docs[1:]:
base_inputs = self._construct_refine_inputs(doc, res)
inputs = {**base_inputs, **kwargs}
inputs = self._construct_refine_inputs(doc, res)
res = await self.refine_llm_chain.apredict(callbacks=callbacks, **inputs)
refine_steps.append(res)
return self._construct_result(refine_steps, res)
@@ -121,23 +171,22 @@ class RefineDocumentsChain(BaseCombineDocumentsChain):
extra_return_dict = {}
return res, extra_return_dict
def _construct_refine_inputs(self, doc: Document, res: str) -> Dict[str, Any]:
def _construct_refine_inputs(
self, doc: Document, res: str, **kwargs: Any
) -> Dict[str, Any]:
return {
self.document_variable_name: format_document(doc, self.document_prompt),
self.initial_response_name: res,
**kwargs,
}
def _construct_initial_inputs(
self, docs: List[Document], **kwargs: Any
) -> Dict[str, Any]:
base_info = {"page_content": docs[0].page_content}
base_info.update(docs[0].metadata)
document_info = {k: base_info[k] for k in self.document_prompt.input_variables}
base_inputs: dict = {
self.document_variable_name: self.document_prompt.format(**document_info)
return {
self.document_variable_name: format_document(docs[0], self.document_prompt),
**kwargs,
}
inputs = {**base_inputs, **kwargs}
return inputs
@property
def _chain_type(self) -> str:

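The new from_llm constructor plus inferred variable names make setup a one-liner; a minimal sketch (assuming OpenAI credentials and a docs list):

    from langchain.chains.combine_documents.refine import RefineDocumentsChain
    from langchain.chains.summarize.refine_prompts import PROMPT, REFINE_PROMPT
    from langchain.llms import OpenAI

    # document_variable_name is inferred from PROMPT's single variable, and
    # initial_response_name from the second variable of REFINE_PROMPT.
    chain = RefineDocumentsChain.from_llm(OpenAI(temperature=0), PROMPT, REFINE_PROMPT)
    summary = chain.run(input_documents=docs)
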
View File

@@ -1,22 +1,20 @@
"""Chain that combines documents by stuffing into context."""
from __future__ import annotations
from typing import Any, Dict, List, Optional, Tuple
from pydantic import Extra, Field, root_validator
from langchain.base_language import BaseLanguageModel
from langchain.callbacks.manager import Callbacks
from langchain.chains.combine_documents.base import (
BaseCombineDocumentsChain,
format_document,
get_default_document_prompt,
)
from langchain.chains.llm import LLMChain
from langchain.docstore.document import Document
from langchain.prompts.base import BasePromptTemplate
from langchain.prompts.prompt import PromptTemplate
def _get_default_document_prompt() -> PromptTemplate:
return PromptTemplate(input_variables=["page_content"], template="{page_content}")
class StuffDocumentsChain(BaseCombineDocumentsChain):
@@ -25,7 +23,7 @@ class StuffDocumentsChain(BaseCombineDocumentsChain):
llm_chain: LLMChain
"""LLM wrapper to use after formatting documents."""
document_prompt: BasePromptTemplate = Field(
default_factory=_get_default_document_prompt
default_factory=get_default_document_prompt
)
"""Prompt to use to format each document."""
document_variable_name: str
@@ -34,12 +32,27 @@ class StuffDocumentsChain(BaseCombineDocumentsChain):
document_separator: str = "\n\n"
"""The string with which to join the formatted documents"""
@classmethod
def from_llm(
cls, llm: BaseLanguageModel, prompt: BasePromptTemplate, **kwargs: Any
) -> StuffDocumentsChain:
"""Initialize StuffDocumentsChain from an LLM."""
llm_chain = LLMChain(llm=llm, prompt=prompt)
return cls(llm_chain=llm_chain, **kwargs)
class Config:
"""Configuration for this pydantic object."""
extra = Extra.forbid
arbitrary_types_allowed = True
@property
def input_keys(self) -> List[str]:
"""Return output keys."""
all_keys = set([self.input_documents_key] + self.llm_chain.input_keys)
internal_keys = [self.document_variable_name]
return list(all_keys.difference(internal_keys))
@root_validator(pre=True)
def get_default_document_variable_name(cls, values: Dict) -> Dict:
"""Get default document variable name, if not provided."""

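A quick sketch of the new StuffDocumentsChain.from_llm helper (prompt wording is illustrative; the single prompt variable doubles as document_variable_name):

    from langchain import PromptTemplate
    from langchain.chains.combine_documents.stuff import StuffDocumentsChain
    from langchain.llms import OpenAI

    prompt = PromptTemplate.from_template("Summarize the following:\n\n{context}")
    chain = StuffDocumentsChain.from_llm(OpenAI(temperature=0), prompt)
    summary = chain.run(input_documents=docs)
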
View File

@@ -30,6 +30,8 @@ class BaseRetrievalQA(Chain):
output_key: str = "result" #: :meta private:
return_source_documents: bool = False
"""Return the source documents."""
combine_documents_chain_question_key: str = "question"
"""TODO: Find a more general / clean approach."""
class Config:
"""Configuration for this pydantic object."""
@@ -62,17 +64,20 @@ class BaseRetrievalQA(Chain):
cls,
llm: BaseLanguageModel,
prompt: Optional[PromptTemplate] = None,
document_variable_name: Optional[str] = None,
**kwargs: Any,
) -> BaseRetrievalQA:
"""Initialize from LLM."""
_document_variable_name = document_variable_name or "context"
_prompt = prompt or PROMPT_SELECTOR.get_prompt(llm)
llm_chain = LLMChain(llm=llm, prompt=_prompt)
document_prompt = PromptTemplate(
input_variables=["page_content"], template="Context:\n{page_content}"
)
combine_documents_chain = StuffDocumentsChain(
llm_chain=llm_chain,
document_variable_name="context",
document_variable_name=_document_variable_name,
document_prompt=document_prompt,
)
@@ -97,6 +102,21 @@ class BaseRetrievalQA(Chain):
def _get_docs(self, question: str) -> List[Document]:
"""Get documents to do question answering over."""
def _get_combine_documents_chain_inputs(
self, docs: List[Document], question: str, **kwargs: Any
) -> Dict[str, Any]:
_kwargs = {
k: v
for k, v in kwargs.items()
if k in self.combine_documents_chain.input_keys
}
inputs = {
self.combine_documents_chain.input_documents_key: docs,
self.combine_documents_chain_question_key: question,
**_kwargs,
}
return inputs
def _call(
self,
inputs: Dict[str, Any],
@@ -114,11 +134,12 @@ class BaseRetrievalQA(Chain):
answer, docs = res['result'], res['source_documents']
"""
_run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager()
question = inputs[self.input_key]
docs = self._get_docs(question)
_inputs = self._get_combine_documents_chain_inputs(docs, question, **inputs)
answer = self.combine_documents_chain.run(
input_documents=docs, question=question, callbacks=_run_manager.get_child()
callbacks=_run_manager.get_child(), **_inputs
)
if self.return_source_documents:
@@ -147,11 +168,12 @@ class BaseRetrievalQA(Chain):
answer, docs = res['result'], res['source_documents']
"""
_run_manager = run_manager or AsyncCallbackManagerForChainRun.get_noop_manager()
question = inputs[self.input_key]
docs = await self._aget_docs(question)
_inputs = self._get_combine_documents_chain_inputs(docs, question, **inputs)
answer = await self.combine_documents_chain.arun(
input_documents=docs, question=question, callbacks=_run_manager.get_child()
callbacks=_run_manager.get_child(), **_inputs
)
if self.return_source_documents:
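
End-to-end, the retrieval QA change lets callers rename the context variable; a minimal sketch (assuming an existing vectorstore and OpenAI credentials):

    from langchain.chains import RetrievalQA
    from langchain.llms import OpenAI

    qa = RetrievalQA.from_llm(
        llm=OpenAI(temperature=0),
        retriever=vectorstore.as_retriever(),
        document_variable_name="context",  # new parameter; defaults to "context"
    )
    answer = qa.run("What does this refactor change?")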