mirror of
https://github.com/hwchase17/langchain.git
synced 2026-01-23 21:31:02 +00:00
Compare commits
1 Commits
langchain=
...
dev2049/co
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
4f5eee1d45 |
@@ -5,6 +5,7 @@ from typing import Any, Dict, List, Optional, Tuple
|
||||
|
||||
from pydantic import Field
|
||||
|
||||
from langchain import PromptTemplate
|
||||
from langchain.callbacks.manager import (
|
||||
AsyncCallbackManagerForChainRun,
|
||||
CallbackManagerForChainRun,
|
||||
@@ -15,10 +16,13 @@ from langchain.prompts.base import BasePromptTemplate
|
||||
from langchain.text_splitter import RecursiveCharacterTextSplitter, TextSplitter
|
||||
|
||||
|
||||
def get_default_document_prompt() -> PromptTemplate:
    """Build the fallback prompt used to render a single document.

    Returns:
        A ``PromptTemplate`` with ``page_content`` as its only input variable
        and ``"{page_content}"`` as its template.
    """
    passthrough_template = "{page_content}"
    return PromptTemplate(
        template=passthrough_template,
        input_variables=["page_content"],
    )
|
||||
|
||||
|
||||
def format_document(doc: Document, prompt: BasePromptTemplate) -> str:
|
||||
"""Format a document into a string based on a prompt template."""
|
||||
base_info = {"page_content": doc.page_content}
|
||||
base_info.update(doc.metadata)
|
||||
base_info = {"page_content": doc.page_content, **doc.metadata}
|
||||
missing_metadata = set(prompt.input_variables).difference(base_info)
|
||||
if len(missing_metadata) > 0:
|
||||
required_metadata = [
|
||||
@@ -36,17 +40,9 @@ def format_document(doc: Document, prompt: BasePromptTemplate) -> str:
|
||||
class BaseCombineDocumentsChain(Chain, ABC):
|
||||
"""Base interface for chains combining documents."""
|
||||
|
||||
input_key: str = "input_documents" #: :meta private:
|
||||
input_documents_key: str = "input_documents" #: :meta private:
|
||||
output_key: str = "output_text" #: :meta private:
|
||||
|
||||
@property
|
||||
def input_keys(self) -> List[str]:
|
||||
"""Expect input key.
|
||||
|
||||
:meta private:
|
||||
"""
|
||||
return [self.input_key]
|
||||
|
||||
@property
|
||||
def output_keys(self) -> List[str]:
|
||||
"""Return output key.
|
||||
@@ -78,9 +74,9 @@ class BaseCombineDocumentsChain(Chain, ABC):
|
||||
run_manager: Optional[CallbackManagerForChainRun] = None,
|
||||
) -> Dict[str, str]:
|
||||
_run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager()
|
||||
docs = inputs[self.input_key]
|
||||
docs = inputs[self.input_documents_key]
|
||||
# Other keys are assumed to be needed for LLM prediction
|
||||
other_keys = {k: v for k, v in inputs.items() if k != self.input_key}
|
||||
other_keys = {k: v for k, v in inputs.items() if k != self.input_documents_key}
|
||||
output, extra_return_dict = self.combine_docs(
|
||||
docs, callbacks=_run_manager.get_child(), **other_keys
|
||||
)
|
||||
@@ -93,9 +89,9 @@ class BaseCombineDocumentsChain(Chain, ABC):
|
||||
run_manager: Optional[AsyncCallbackManagerForChainRun] = None,
|
||||
) -> Dict[str, str]:
|
||||
_run_manager = run_manager or AsyncCallbackManagerForChainRun.get_noop_manager()
|
||||
docs = inputs[self.input_key]
|
||||
docs = inputs[self.input_documents_key]
|
||||
# Other keys are assumed to be needed for LLM prediction
|
||||
other_keys = {k: v for k, v in inputs.items() if k != self.input_key}
|
||||
other_keys = {k: v for k, v in inputs.items() if k != self.input_documents_key}
|
||||
output, extra_return_dict = await self.acombine_docs(
|
||||
docs, callbacks=_run_manager.get_child(), **other_keys
|
||||
)
|
||||
@@ -136,7 +132,7 @@ class AnalyzeDocumentChain(Chain):
|
||||
docs = self.text_splitter.create_documents([document])
|
||||
# Other keys are assumed to be needed for LLM prediction
|
||||
other_keys: Dict = {k: v for k, v in inputs.items() if k != self.input_key}
|
||||
other_keys[self.combine_docs_chain.input_key] = docs
|
||||
other_keys[self.combine_docs_chain.input_documents_key] = docs
|
||||
return self.combine_docs_chain(
|
||||
other_keys, return_only_outputs=True, callbacks=_run_manager.get_child()
|
||||
)
|
||||
|
||||
@@ -4,10 +4,16 @@ from __future__ import annotations
|
||||
|
||||
from typing import Any, Callable, Dict, List, Optional, Protocol, Tuple
|
||||
|
||||
from pydantic import Extra, root_validator
|
||||
from mypy_extensions import KwArg
|
||||
from pydantic import Extra, Field, root_validator
|
||||
|
||||
from langchain import BasePromptTemplate
|
||||
from langchain.callbacks.manager import Callbacks
|
||||
from langchain.chains.combine_documents.base import BaseCombineDocumentsChain
|
||||
from langchain.chains.combine_documents.base import (
|
||||
BaseCombineDocumentsChain,
|
||||
format_document,
|
||||
get_default_document_prompt,
|
||||
)
|
||||
from langchain.chains.llm import LLMChain
|
||||
from langchain.docstore.document import Document
|
||||
|
||||
@@ -46,10 +52,8 @@ def _split_list_of_docs(
|
||||
|
||||
def _collapse_docs(
|
||||
docs: List[Document],
|
||||
combine_document_func: CombineDocsProtocol,
|
||||
**kwargs: Any,
|
||||
combine_docs_result: str,
|
||||
) -> Document:
|
||||
result = combine_document_func(docs, **kwargs)
|
||||
combined_metadata = {k: str(v) for k, v in docs[0].metadata.items()}
|
||||
for doc in docs[1:]:
|
||||
for k, v in doc.metadata.items():
|
||||
@@ -57,7 +61,7 @@ def _collapse_docs(
|
||||
combined_metadata[k] += f", {v}"
|
||||
else:
|
||||
combined_metadata[k] = str(v)
|
||||
return Document(page_content=result, metadata=combined_metadata)
|
||||
return Document(page_content=combine_docs_result, metadata=combined_metadata)
|
||||
|
||||
|
||||
class MapReduceDocumentsChain(BaseCombineDocumentsChain):
|
||||
@@ -65,6 +69,10 @@ class MapReduceDocumentsChain(BaseCombineDocumentsChain):
|
||||
|
||||
llm_chain: LLMChain
|
||||
"""Chain to apply to each document individually."""
|
||||
document_prompt: BasePromptTemplate = Field(
|
||||
default_factory=get_default_document_prompt
|
||||
)
|
||||
"""Prompt to use to format each document."""
|
||||
combine_document_chain: BaseCombineDocumentsChain
|
||||
"""Chain to use to combine results of applying llm_chain to documents."""
|
||||
collapse_document_chain: Optional[BaseCombineDocumentsChain] = None
|
||||
@@ -76,9 +84,24 @@ class MapReduceDocumentsChain(BaseCombineDocumentsChain):
|
||||
return_intermediate_steps: bool = False
|
||||
"""Return the results of the map steps in the output."""
|
||||
|
||||
@property
def input_keys(self) -> List[str]:
    """Return the keys a caller must supply to this chain.

    Starts from this chain's documents key, ``"token_max"``, and the input
    keys of ``llm_chain`` and ``_collapse_chain``, then removes the keys that
    are produced internally: ``document_variable_name`` and the documents keys
    of ``combine_document_chain`` and ``_collapse_chain``.

    Returns:
        The remaining keys as a list (built from a set, so order is not
        significant).
    """
    exposed = {self.input_documents_key, "token_max"}
    exposed.update(self.llm_chain.input_keys)
    exposed.update(self._collapse_chain.input_keys)
    # Keys populated by the chain itself rather than by the caller.
    produced_internally = {
        self.document_variable_name,
        self.combine_document_chain.input_documents_key,
        self._collapse_chain.input_documents_key,
    }
    return list(exposed - produced_internally)
|
||||
|
||||
@property
|
||||
def output_keys(self) -> List[str]:
|
||||
"""Expect input key.
|
||||
"""Return output keys.
|
||||
|
||||
:meta private:
|
||||
"""
|
||||
@@ -129,6 +152,13 @@ class MapReduceDocumentsChain(BaseCombineDocumentsChain):
|
||||
else:
|
||||
return self.combine_document_chain
|
||||
|
||||
def _get_llm_chain_inputs(self, docs: List[Document], **kwargs: Any) -> List[dict]:
    """Build one ``llm_chain`` input dict per document.

    Args:
        docs: Documents to render with ``document_prompt``.
        **kwargs: Extra values; only those named in ``llm_chain.input_keys``
            are forwarded.

    Returns:
        A list with one dict per document, mapping ``document_variable_name``
        to the rendered document text plus the forwarded extra values.
    """
    # Render every document through the document prompt first.
    rendered_docs = []
    for doc in docs:
        rendered_docs.append(format_document(doc, self.document_prompt))
    # Keep only the extra inputs the LLM chain actually declares.
    accepted = self.llm_chain.input_keys
    shared_inputs = {
        name: value for name, value in kwargs.items() if name in accepted
    }
    # One independent input mapping per document (nothing is concatenated).
    return [
        {self.document_variable_name: text, **shared_inputs}
        for text in rendered_docs
    ]
|
||||
|
||||
def combine_docs(
|
||||
self,
|
||||
docs: List[Document],
|
||||
@@ -141,35 +171,41 @@ class MapReduceDocumentsChain(BaseCombineDocumentsChain):
|
||||
Combine by mapping first chain over all documents, then reducing the results.
|
||||
This reducing can be done recursively if needed (if there are many documents).
|
||||
"""
|
||||
results = self.llm_chain.apply(
|
||||
# FYI - this is parallelized and so it is fast.
|
||||
[{self.document_variable_name: d.page_content, **kwargs} for d in docs],
|
||||
callbacks=callbacks,
|
||||
)
|
||||
inputs = self._get_llm_chain_inputs(docs, **kwargs)
|
||||
# FYI - this is parallelized and so it is fast.
|
||||
results = self.llm_chain.apply(inputs, callbacks=callbacks)
|
||||
return self._process_results(
|
||||
results, docs, token_max, callbacks=callbacks, **kwargs
|
||||
)
|
||||
|
||||
async def acombine_docs(
|
||||
self, docs: List[Document], callbacks: Callbacks = None, **kwargs: Any
|
||||
self,
|
||||
docs: List[Document],
|
||||
token_max: int = 3000,
|
||||
callbacks: Callbacks = None,
|
||||
**kwargs: Any,
|
||||
) -> Tuple[str, dict]:
|
||||
"""Combine documents in a map reduce manner.
|
||||
|
||||
Combine by mapping first chain over all documents, then reducing the results.
|
||||
This reducing can be done recursively if needed (if there are many documents).
|
||||
"""
|
||||
results = await self.llm_chain.aapply(
|
||||
# FYI - this is parallelized and so it is fast.
|
||||
[{**{self.document_variable_name: d.page_content}, **kwargs} for d in docs],
|
||||
callbacks=callbacks,
|
||||
inputs = self._get_llm_chain_inputs(docs, **kwargs)
|
||||
# FYI - this is parallelized and so it is fast.
|
||||
results = await self.llm_chain.aapply(inputs, callbacks=callbacks)
|
||||
return self._process_results(
|
||||
results, docs, token_max, callbacks=callbacks, **kwargs
|
||||
)
|
||||
return self._process_results(results, docs, callbacks=callbacks, **kwargs)
|
||||
|
||||
@property
def _length_func(self) -> Callable[[List[Document], KwArg(Any)], Optional[int]]:
    """Return the ``prompt_length`` attribute of ``combine_document_chain``."""
    reducer = self.combine_document_chain
    return reducer.prompt_length  # type: ignore
|
||||
|
||||
def _process_results(
|
||||
self,
|
||||
results: List[Dict],
|
||||
docs: List[Document],
|
||||
token_max: int = 3000,
|
||||
token_max: int,
|
||||
callbacks: Callbacks = None,
|
||||
**kwargs: Any,
|
||||
) -> Tuple[str, dict]:
|
||||
@@ -179,33 +215,28 @@ class MapReduceDocumentsChain(BaseCombineDocumentsChain):
|
||||
# This uses metadata from the docs, and the textual results from `results`
|
||||
for i, r in enumerate(results)
|
||||
]
|
||||
length_func = self.combine_document_chain.prompt_length
|
||||
num_tokens = length_func(result_docs, **kwargs)
|
||||
|
||||
def _collapse_docs_func(docs: List[Document], **kwargs: Any) -> str:
|
||||
return self._collapse_chain.run(
|
||||
input_documents=docs, callbacks=callbacks, **kwargs
|
||||
)
|
||||
num_tokens = self._length_func(result_docs, **kwargs)
|
||||
|
||||
while num_tokens is not None and num_tokens > token_max:
|
||||
new_result_doc_list = _split_list_of_docs(
|
||||
result_docs, length_func, token_max, **kwargs
|
||||
result_docs, self._length_func, token_max, **kwargs
|
||||
)
|
||||
result_docs = []
|
||||
for docs in new_result_doc_list:
|
||||
new_doc = _collapse_docs(docs, _collapse_docs_func, **kwargs)
|
||||
result_docs.append(new_doc)
|
||||
num_tokens = self.combine_document_chain.prompt_length(
|
||||
result_docs, **kwargs
|
||||
)
|
||||
inputs = {self._collapse_chain.input_documents_key: docs, **kwargs}
|
||||
result = self._collapse_chain.run(callbacks=callbacks, **inputs)
|
||||
result_docs.append(_collapse_docs(docs, result))
|
||||
num_tokens = self._length_func(result_docs, **kwargs)
|
||||
if self.return_intermediate_steps:
|
||||
_results = [r[self.llm_chain.output_key] for r in results]
|
||||
extra_return_dict = {"intermediate_steps": _results}
|
||||
else:
|
||||
extra_return_dict = {}
|
||||
output = self.combine_document_chain.run(
|
||||
input_documents=result_docs, callbacks=callbacks, **kwargs
|
||||
)
|
||||
inputs = {
|
||||
self.combine_document_chain.input_documents_key: result_docs,
|
||||
**kwargs,
|
||||
}
|
||||
output = self.combine_document_chain.run(callbacks=callbacks, **inputs)
|
||||
return output, extra_return_dict
|
||||
|
||||
@property
|
||||
|
||||
@@ -4,10 +4,15 @@ from __future__ import annotations
|
||||
|
||||
from typing import Any, Dict, List, Optional, Sequence, Tuple, Union, cast
|
||||
|
||||
from pydantic import Extra, root_validator
|
||||
from pydantic import Extra, Field, root_validator
|
||||
|
||||
from langchain import BasePromptTemplate
|
||||
from langchain.callbacks.manager import Callbacks
|
||||
from langchain.chains.combine_documents.base import BaseCombineDocumentsChain
|
||||
from langchain.chains.combine_documents.base import (
|
||||
BaseCombineDocumentsChain,
|
||||
format_document,
|
||||
get_default_document_prompt,
|
||||
)
|
||||
from langchain.chains.llm import LLMChain
|
||||
from langchain.docstore.document import Document
|
||||
from langchain.output_parsers.regex import RegexParser
|
||||
@@ -18,6 +23,10 @@ class MapRerankDocumentsChain(BaseCombineDocumentsChain):
|
||||
|
||||
llm_chain: LLMChain
|
||||
"""Chain to apply to each document individually."""
|
||||
document_prompt: BasePromptTemplate = Field(
|
||||
default_factory=get_default_document_prompt
|
||||
)
|
||||
"""Prompt to use to format each document."""
|
||||
document_variable_name: str
|
||||
"""The variable name in the llm_chain to put the documents in.
|
||||
If only one variable in the llm_chain, this need not be provided."""
|
||||
@@ -34,9 +43,16 @@ class MapRerankDocumentsChain(BaseCombineDocumentsChain):
|
||||
extra = Extra.forbid
|
||||
arbitrary_types_allowed = True
|
||||
|
||||
@property
def input_keys(self) -> List[str]:
    """Return input keys.

    The result is this chain's documents key plus every input key of
    ``llm_chain``, minus ``document_variable_name``.

    Returns:
        The keys as a list (built from a set, so order is not significant).
    """
    all_keys = {self.input_documents_key, *self.llm_chain.input_keys}
    # The document variable is not a caller-supplied input, so hide it.
    internal_keys = {self.document_variable_name}
    return list(all_keys - internal_keys)
|
||||
|
||||
@property
|
||||
def output_keys(self) -> List[str]:
|
||||
"""Expect input key.
|
||||
"""Return output keys.
|
||||
|
||||
:meta private:
|
||||
"""
|
||||
@@ -90,6 +106,13 @@ class MapRerankDocumentsChain(BaseCombineDocumentsChain):
|
||||
)
|
||||
return values
|
||||
|
||||
def _get_llm_chain_inputs(self, docs: List[Document], **kwargs: Any) -> List[dict]:
    """Build one ``llm_chain`` input dict per document.

    Args:
        docs: Documents to render with ``document_prompt``.
        **kwargs: Extra values; only those named in ``llm_chain.input_keys``
            are forwarded.

    Returns:
        A list with one dict per document, mapping ``document_variable_name``
        to the rendered document text plus the forwarded extra values.
    """
    # Render every document through the document prompt first.
    rendered_docs = []
    for doc in docs:
        rendered_docs.append(format_document(doc, self.document_prompt))
    # Keep only the extra inputs the LLM chain actually declares.
    accepted = self.llm_chain.input_keys
    shared_inputs = {
        name: value for name, value in kwargs.items() if name in accepted
    }
    # One independent input mapping per document (nothing is concatenated).
    return [
        {self.document_variable_name: text, **shared_inputs}
        for text in rendered_docs
    ]
|
||||
|
||||
def combine_docs(
|
||||
self, docs: List[Document], callbacks: Callbacks = None, **kwargs: Any
|
||||
) -> Tuple[str, dict]:
|
||||
@@ -97,11 +120,9 @@ class MapRerankDocumentsChain(BaseCombineDocumentsChain):
|
||||
|
||||
Combine by mapping first chain over all documents, then reranking the results.
|
||||
"""
|
||||
results = self.llm_chain.apply_and_parse(
|
||||
# FYI - this is parallelized and so it is fast.
|
||||
[{**{self.document_variable_name: d.page_content}, **kwargs} for d in docs],
|
||||
callbacks=callbacks,
|
||||
)
|
||||
inputs = self._get_llm_chain_inputs(docs, **kwargs)
|
||||
# FYI - this is parallelized and so it is fast.
|
||||
results = self.llm_chain.apply_and_parse(inputs, callbacks=callbacks)
|
||||
return self._process_results(docs, results)
|
||||
|
||||
async def acombine_docs(
|
||||
@@ -111,11 +132,9 @@ class MapRerankDocumentsChain(BaseCombineDocumentsChain):
|
||||
|
||||
Combine by mapping first chain over all documents, then reranking the results.
|
||||
"""
|
||||
results = await self.llm_chain.aapply_and_parse(
|
||||
# FYI - this is parallelized and so it is fast.
|
||||
[{**{self.document_variable_name: d.page_content}, **kwargs} for d in docs],
|
||||
callbacks=callbacks,
|
||||
)
|
||||
inputs = self._get_llm_chain_inputs(docs, **kwargs)
|
||||
# FYI - this is parallelized and so it is fast.
|
||||
results = await self.llm_chain.aapply_and_parse(inputs, callbacks=callbacks)
|
||||
return self._process_results(docs, results)
|
||||
|
||||
def _process_results(
|
||||
|
||||
@@ -6,10 +6,12 @@ from typing import Any, Dict, List, Tuple
|
||||
|
||||
from pydantic import Extra, Field, root_validator
|
||||
|
||||
from langchain.base_language import BaseLanguageModel
|
||||
from langchain.callbacks.manager import Callbacks
|
||||
from langchain.chains.combine_documents.base import (
|
||||
BaseCombineDocumentsChain,
|
||||
format_document,
|
||||
get_default_document_prompt,
|
||||
)
|
||||
from langchain.chains.llm import LLMChain
|
||||
from langchain.docstore.document import Document
|
||||
@@ -17,10 +19,6 @@ from langchain.prompts.base import BasePromptTemplate
|
||||
from langchain.prompts.prompt import PromptTemplate
|
||||
|
||||
|
||||
def _get_default_document_prompt() -> PromptTemplate:
|
||||
return PromptTemplate(input_variables=["page_content"], template="{page_content}")
|
||||
|
||||
|
||||
class RefineDocumentsChain(BaseCombineDocumentsChain):
|
||||
"""Combine documents by doing a first pass and then refining on more documents."""
|
||||
|
||||
@@ -30,19 +28,60 @@ class RefineDocumentsChain(BaseCombineDocumentsChain):
|
||||
"""LLM chain to use when refining."""
|
||||
document_variable_name: str
|
||||
"""The variable name in the initial_llm_chain to put the documents in.
|
||||
If only one variable in the initial_llm_chain, this need not be provided."""
|
||||
If only one variable in the initial_llm_chain, this doesn't need to be specified.
|
||||
"""
|
||||
initial_response_name: str
|
||||
"""The variable name to format the initial response in when refining."""
|
||||
"""The variable name to format the initial response in when refining.
|
||||
If only two variables are in the refine_llm_chain, this doesn't need to be
|
||||
specified.
|
||||
"""
|
||||
document_prompt: BasePromptTemplate = Field(
|
||||
default_factory=_get_default_document_prompt
|
||||
default_factory=get_default_document_prompt
|
||||
)
|
||||
"""Prompt to use to format each document."""
|
||||
return_intermediate_steps: bool = False
|
||||
"""Return the results of the refine steps in the output."""
|
||||
|
||||
@classmethod
def from_llm(
    cls,
    llm: BaseLanguageModel,
    question_prompt: BasePromptTemplate,
    refine_prompt: BasePromptTemplate,
    **kwargs: Any,
) -> RefineDocumentsChain:
    """Initialize RefineDocumentsChain from an LLM and two prompts.

    Args:
        llm: Language model used by both the initial and the refine LLMChain.
        question_prompt: Prompt wrapped in the chain passed as
            ``initial_llm_chain``.
        refine_prompt: Prompt wrapped in the chain passed as
            ``refine_llm_chain``.
        **kwargs: Forwarded unchanged to the class constructor.

    Returns:
        A new instance of ``cls`` built from the two LLM chains.

    Example:
        .. code-block:: python

            from langchain.chains.combine_documents import RefineDocumentsChain
            from langchain.chains.summarize.refine_prompts import (
                PROMPT,
                REFINE_PROMPT,
            )
            from langchain.llms import OpenAI

            refine_docs_chain = RefineDocumentsChain.from_llm(
                OpenAI(), PROMPT, REFINE_PROMPT
            )
    """
    initial_chain = LLMChain(llm=llm, prompt=question_prompt)
    refine_chain = LLMChain(llm=llm, prompt=refine_prompt)
    # Construct through ``cls`` rather than the hard-coded class name so that
    # a subclass calling ``from_llm`` receives an instance of itself.
    return cls(
        initial_llm_chain=initial_chain,
        refine_llm_chain=refine_chain,
        **kwargs,
    )
|
||||
|
||||
@property
def input_keys(self) -> List[str]:
    """Return the keys a caller must supply to this chain.

    Collects this chain's documents key together with the input keys of
    ``initial_llm_chain`` and ``refine_llm_chain``, then removes
    ``document_variable_name`` and ``initial_response_name``.

    Returns:
        The remaining keys as a list (built from a set, so order is not
        significant).
    """
    candidates = {self.input_documents_key}
    candidates.update(self.initial_llm_chain.input_keys)
    candidates.update(self.refine_llm_chain.input_keys)
    # These two variables are removed from the advertised inputs.
    filled_internally = {self.document_variable_name, self.initial_response_name}
    return list(candidates - filled_internally)
|
||||
|
||||
@property
|
||||
def output_keys(self) -> List[str]:
|
||||
"""Expect input key.
|
||||
"""Return output keys.
|
||||
|
||||
:meta private:
|
||||
"""
|
||||
@@ -66,23 +105,36 @@ class RefineDocumentsChain(BaseCombineDocumentsChain):
|
||||
return values
|
||||
|
||||
@root_validator(pre=True)
|
||||
def get_default_document_variable_name(cls, values: Dict) -> Dict:
|
||||
"""Get default document variable name, if not provided."""
|
||||
def get_default_variable_names(cls, values: Dict) -> Dict:
|
||||
"""Infer and validate sub-chain input variable names, if not provided."""
|
||||
initial_inputs = values["initial_llm_chain"].input_keys
|
||||
if "document_variable_name" not in values:
|
||||
llm_chain_variables = values["initial_llm_chain"].prompt.input_variables
|
||||
if len(llm_chain_variables) == 1:
|
||||
values["document_variable_name"] = llm_chain_variables[0]
|
||||
if len(initial_inputs) == 1:
|
||||
values["document_variable_name"] = initial_inputs[0]
|
||||
else:
|
||||
raise ValueError(
|
||||
"document_variable_name must be provided if there are "
|
||||
"multiple llm_chain input_variables"
|
||||
"multiple initial_llm_chain input_variables"
|
||||
)
|
||||
else:
|
||||
llm_chain_variables = values["initial_llm_chain"].prompt.input_variables
|
||||
if values["document_variable_name"] not in llm_chain_variables:
|
||||
if values["document_variable_name"] not in initial_inputs:
|
||||
raise ValueError(
|
||||
f"document_variable_name {values['document_variable_name']} was "
|
||||
f"not found in llm_chain input_variables: {llm_chain_variables}"
|
||||
f"not found in initial_llm_chain input_keys: {initial_inputs}"
|
||||
)
|
||||
refine_inputs = values["refine_llm_chain"].input_keys
|
||||
if "initial_response_name" not in values:
|
||||
doc_input = values["document_variable_name"]
|
||||
if len(refine_inputs) == 2:
|
||||
init_resp_input = [i for i in refine_inputs if i != doc_input][0]
|
||||
values["initial_response_name"] = init_resp_input
|
||||
else:
|
||||
raise ValueError
|
||||
else:
|
||||
if values["initial_response_name"] not in refine_inputs:
|
||||
raise ValueError(
|
||||
f"initial_response_name {values['initial_response_name']} was not "
|
||||
f"found in refine_llm_chain input_keys: {refine_inputs}"
|
||||
)
|
||||
return values
|
||||
|
||||
@@ -94,8 +146,7 @@ class RefineDocumentsChain(BaseCombineDocumentsChain):
|
||||
res = self.initial_llm_chain.predict(callbacks=callbacks, **inputs)
|
||||
refine_steps = [res]
|
||||
for doc in docs[1:]:
|
||||
base_inputs = self._construct_refine_inputs(doc, res)
|
||||
inputs = {**base_inputs, **kwargs}
|
||||
inputs = self._construct_refine_inputs(doc, res)
|
||||
res = self.refine_llm_chain.predict(callbacks=callbacks, **inputs)
|
||||
refine_steps.append(res)
|
||||
return self._construct_result(refine_steps, res)
|
||||
@@ -108,8 +159,7 @@ class RefineDocumentsChain(BaseCombineDocumentsChain):
|
||||
res = await self.initial_llm_chain.apredict(callbacks=callbacks, **inputs)
|
||||
refine_steps = [res]
|
||||
for doc in docs[1:]:
|
||||
base_inputs = self._construct_refine_inputs(doc, res)
|
||||
inputs = {**base_inputs, **kwargs}
|
||||
inputs = self._construct_refine_inputs(doc, res)
|
||||
res = await self.refine_llm_chain.apredict(callbacks=callbacks, **inputs)
|
||||
refine_steps.append(res)
|
||||
return self._construct_result(refine_steps, res)
|
||||
@@ -121,23 +171,22 @@ class RefineDocumentsChain(BaseCombineDocumentsChain):
|
||||
extra_return_dict = {}
|
||||
return res, extra_return_dict
|
||||
|
||||
def _construct_refine_inputs(self, doc: Document, res: str) -> Dict[str, Any]:
|
||||
def _construct_refine_inputs(
|
||||
self, doc: Document, res: str, **kwargs: Any
|
||||
) -> Dict[str, Any]:
|
||||
return {
|
||||
self.document_variable_name: format_document(doc, self.document_prompt),
|
||||
self.initial_response_name: res,
|
||||
**kwargs,
|
||||
}
|
||||
|
||||
def _construct_initial_inputs(
|
||||
self, docs: List[Document], **kwargs: Any
|
||||
) -> Dict[str, Any]:
|
||||
base_info = {"page_content": docs[0].page_content}
|
||||
base_info.update(docs[0].metadata)
|
||||
document_info = {k: base_info[k] for k in self.document_prompt.input_variables}
|
||||
base_inputs: dict = {
|
||||
self.document_variable_name: self.document_prompt.format(**document_info)
|
||||
return {
|
||||
self.document_variable_name: format_document(docs[0], self.document_prompt),
|
||||
**kwargs,
|
||||
}
|
||||
inputs = {**base_inputs, **kwargs}
|
||||
return inputs
|
||||
|
||||
@property
|
||||
def _chain_type(self) -> str:
|
||||
|
||||
@@ -1,22 +1,20 @@
|
||||
"""Chain that combines documents by stuffing into context."""
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
|
||||
from pydantic import Extra, Field, root_validator
|
||||
|
||||
from langchain.base_language import BaseLanguageModel
|
||||
from langchain.callbacks.manager import Callbacks
|
||||
from langchain.chains.combine_documents.base import (
|
||||
BaseCombineDocumentsChain,
|
||||
format_document,
|
||||
get_default_document_prompt,
|
||||
)
|
||||
from langchain.chains.llm import LLMChain
|
||||
from langchain.docstore.document import Document
|
||||
from langchain.prompts.base import BasePromptTemplate
|
||||
from langchain.prompts.prompt import PromptTemplate
|
||||
|
||||
|
||||
def _get_default_document_prompt() -> PromptTemplate:
|
||||
return PromptTemplate(input_variables=["page_content"], template="{page_content}")
|
||||
|
||||
|
||||
class StuffDocumentsChain(BaseCombineDocumentsChain):
|
||||
@@ -25,7 +23,7 @@ class StuffDocumentsChain(BaseCombineDocumentsChain):
|
||||
llm_chain: LLMChain
|
||||
"""LLM wrapper to use after formatting documents."""
|
||||
document_prompt: BasePromptTemplate = Field(
|
||||
default_factory=_get_default_document_prompt
|
||||
default_factory=get_default_document_prompt
|
||||
)
|
||||
"""Prompt to use to format each document."""
|
||||
document_variable_name: str
|
||||
@@ -34,12 +32,27 @@ class StuffDocumentsChain(BaseCombineDocumentsChain):
|
||||
document_separator: str = "\n\n"
|
||||
"""The string with which to join the formatted documents"""
|
||||
|
||||
@classmethod
def from_llm(
    cls, llm: BaseLanguageModel, prompt: BasePromptTemplate, **kwargs: Any
) -> StuffDocumentsChain:
    """Create a StuffDocumentsChain from a language model and a prompt.

    Args:
        llm: Language model wrapped in the chain's ``LLMChain``.
        prompt: Prompt given to that ``LLMChain``.
        **kwargs: Forwarded unchanged to the class constructor.

    Returns:
        A new instance of ``cls`` whose ``llm_chain`` is the ``LLMChain``
        built from ``llm`` and ``prompt``.
    """
    return cls(llm_chain=LLMChain(llm=llm, prompt=prompt), **kwargs)
|
||||
|
||||
class Config:
|
||||
"""Configuration for this pydantic object."""
|
||||
|
||||
extra = Extra.forbid
|
||||
arbitrary_types_allowed = True
|
||||
|
||||
@property
def input_keys(self) -> List[str]:
    """Return input keys.

    The result is this chain's documents key plus every input key of
    ``llm_chain``, minus ``document_variable_name``.

    Returns:
        The keys as a list (built from a set, so order is not significant).
    """
    all_keys = {self.input_documents_key, *self.llm_chain.input_keys}
    # The document variable is not a caller-supplied input, so hide it.
    internal_keys = {self.document_variable_name}
    return list(all_keys - internal_keys)
|
||||
|
||||
@root_validator(pre=True)
|
||||
def get_default_document_variable_name(cls, values: Dict) -> Dict:
|
||||
"""Get default document variable name, if not provided."""
|
||||
|
||||
@@ -30,6 +30,8 @@ class BaseRetrievalQA(Chain):
|
||||
output_key: str = "result" #: :meta private:
|
||||
return_source_documents: bool = False
|
||||
"""Return the source documents."""
|
||||
combine_documents_chain_question_key: str = "question"
|
||||
"""TODO: Find a more general / clean approach."""
|
||||
|
||||
class Config:
|
||||
"""Configuration for this pydantic object."""
|
||||
@@ -62,17 +64,20 @@ class BaseRetrievalQA(Chain):
|
||||
cls,
|
||||
llm: BaseLanguageModel,
|
||||
prompt: Optional[PromptTemplate] = None,
|
||||
document_variable_name: Optional[str] = None,
|
||||
**kwargs: Any,
|
||||
) -> BaseRetrievalQA:
|
||||
"""Initialize from LLM."""
|
||||
_prompt = prompt or PROMPT_SELECTOR.get_prompt(llm)
|
||||
_document_variable_name = document_variable_name or "context"
|
||||
_prompt = prompt or PROMPT_SELECTOR.get_prompt(llm)
|
||||
llm_chain = LLMChain(llm=llm, prompt=_prompt)
|
||||
document_prompt = PromptTemplate(
|
||||
input_variables=["page_content"], template="Context:\n{page_content}"
|
||||
)
|
||||
combine_documents_chain = StuffDocumentsChain(
|
||||
llm_chain=llm_chain,
|
||||
document_variable_name="context",
|
||||
document_variable_name=document_variable_name,
|
||||
document_prompt=document_prompt,
|
||||
)
|
||||
|
||||
@@ -97,6 +102,21 @@ class BaseRetrievalQA(Chain):
|
||||
def _get_docs(self, question: str) -> List[Document]:
|
||||
"""Get documents to do question answering over."""
|
||||
|
||||
def _get_combine_documents_chain_inputs(
    self, docs: List[Document], question: str, **kwargs: Any
) -> Dict[str, Any]:
    """Assemble the input mapping passed to ``combine_documents_chain``.

    Args:
        docs: Documents stored under the combine chain's documents key.
        question: Value stored under ``combine_documents_chain_question_key``.
        **kwargs: Extra values; only those named in the combine chain's
            ``input_keys`` are kept, and a kept value overrides an earlier
            entry with the same key.

    Returns:
        The assembled dict of inputs.
    """
    combine_chain = self.combine_documents_chain
    chain_inputs: Dict[str, Any] = {
        combine_chain.input_documents_key: docs,
        self.combine_documents_chain_question_key: question,
    }
    # Layer on only the extra inputs the combine chain declares.
    for name, value in kwargs.items():
        if name in combine_chain.input_keys:
            chain_inputs[name] = value
    return chain_inputs
|
||||
|
||||
def _call(
|
||||
self,
|
||||
inputs: Dict[str, Any],
|
||||
@@ -114,11 +134,12 @@ class BaseRetrievalQA(Chain):
|
||||
answer, docs = res['result'], res['source_documents']
|
||||
"""
|
||||
_run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager()
|
||||
question = inputs[self.input_key]
|
||||
|
||||
question = inputs[self.input_key]
|
||||
docs = self._get_docs(question)
|
||||
_inputs = self._get_combine_documents_chain_inputs(docs, question, **inputs)
|
||||
answer = self.combine_documents_chain.run(
|
||||
input_documents=docs, question=question, callbacks=_run_manager.get_child()
|
||||
callbacks=_run_manager.get_child(), **_inputs
|
||||
)
|
||||
|
||||
if self.return_source_documents:
|
||||
@@ -147,11 +168,12 @@ class BaseRetrievalQA(Chain):
|
||||
answer, docs = res['result'], res['source_documents']
|
||||
"""
|
||||
_run_manager = run_manager or AsyncCallbackManagerForChainRun.get_noop_manager()
|
||||
question = inputs[self.input_key]
|
||||
|
||||
question = inputs[self.input_key]
|
||||
docs = await self._aget_docs(question)
|
||||
_inputs = self._get_combine_documents_chain_inputs(docs, question, **inputs)
|
||||
answer = await self.combine_documents_chain.arun(
|
||||
input_documents=docs, question=question, callbacks=_run_manager.get_child()
|
||||
callbacks=_run_manager.get_child(), **_inputs
|
||||
)
|
||||
|
||||
if self.return_source_documents:
|
||||
|
||||
Reference in New Issue
Block a user