Mirror of https://github.com/hwchase17/langchain.git (synced 2025-09-25 13:07:58 +00:00)

Docs combine document chain (#6994)

Co-authored-by: Dev 2049 <dev.dev2049@gmail.com>
Co-authored-by: Bagatur <baskaryan@gmail.com>
langchain/chains/__init__.py

@@ -4,6 +4,7 @@ from langchain.chains.api.openapi.chain import OpenAPIEndpointChain
 from langchain.chains.combine_documents.base import AnalyzeDocumentChain
 from langchain.chains.combine_documents.map_reduce import MapReduceDocumentsChain
 from langchain.chains.combine_documents.map_rerank import MapRerankDocumentsChain
+from langchain.chains.combine_documents.reduce import ReduceDocumentsChain
 from langchain.chains.combine_documents.refine import RefineDocumentsChain
 from langchain.chains.combine_documents.stuff import StuffDocumentsChain
 from langchain.chains.constitutional_ai.base import ConstitutionalChain
@@ -111,4 +112,5 @@ __all__ = [
     "MapRerankDocumentsChain",
     "MapReduceDocumentsChain",
     "RefineDocumentsChain",
+    "ReduceDocumentsChain",
 ]
langchain/chains/combine_documents/base.py

@@ -11,30 +11,20 @@ from langchain.callbacks.manager import (
 )
 from langchain.chains.base import Chain
 from langchain.docstore.document import Document
-from langchain.schema import BasePromptTemplate
 from langchain.text_splitter import RecursiveCharacterTextSplitter, TextSplitter
 
 
-def format_document(doc: Document, prompt: BasePromptTemplate) -> str:
-    """Format a document into a string based on a prompt template."""
-    base_info = {"page_content": doc.page_content}
-    base_info.update(doc.metadata)
-    missing_metadata = set(prompt.input_variables).difference(base_info)
-    if len(missing_metadata) > 0:
-        required_metadata = [
-            iv for iv in prompt.input_variables if iv != "page_content"
-        ]
-        raise ValueError(
-            f"Document prompt requires documents to have metadata variables: "
-            f"{required_metadata}. Received document with missing metadata: "
-            f"{list(missing_metadata)}."
-        )
-    document_info = {k: base_info[k] for k in prompt.input_variables}
-    return prompt.format(**document_info)
-
-
 class BaseCombineDocumentsChain(Chain, ABC):
-    """Base interface for chains combining documents."""
+    """Base interface for chains combining documents.
+
+    Subclasses of this chain deal with combining documents in a variety of
+    ways. This base class exists to add some uniformity in the interface these types
+    of chains should expose. Namely, they expect an input key related to the documents
+    to use (default `input_documents`), and then also expose a method to calculate
+    the length of a prompt from documents (useful for outside callers to use to
+    determine whether it's safe to pass a list of documents into this chain or whether
+    that will be longer than the context length).
+    """
 
     input_key: str = "input_documents"  #: :meta private:
     output_key: str = "output_text"  #: :meta private:
@@ -58,25 +48,57 @@ class BaseCombineDocumentsChain(Chain, ABC):
     def prompt_length(self, docs: List[Document], **kwargs: Any) -> Optional[int]:
         """Return the prompt length given the documents passed in.
 
-        Returns None if the method does not depend on the prompt length.
+        This can be used by a caller to determine whether passing in a list
+        of documents would exceed a certain prompt length. This is useful when
+        trying to ensure that the size of a prompt remains below a certain
+        context limit.
+
+        Args:
+            docs: List[Document], a list of documents to use to calculate the
+                total prompt length.
+
+        Returns:
+            Returns None if the method does not depend on the prompt length,
+            otherwise the length of the prompt in tokens.
         """
         return None
 
     @abstractmethod
     def combine_docs(self, docs: List[Document], **kwargs: Any) -> Tuple[str, dict]:
-        """Combine documents into a single string."""
+        """Combine documents into a single string.
+
+        Args:
+            docs: List[Document], the documents to combine
+            **kwargs: Other parameters to use in combining documents, often
+                other inputs to the prompt.
+
+        Returns:
+            The first element returned is the single string output. The second
+            element returned is a dictionary of other keys to return.
+        """
 
     @abstractmethod
     async def acombine_docs(
         self, docs: List[Document], **kwargs: Any
     ) -> Tuple[str, dict]:
-        """Combine documents into a single string asynchronously."""
+        """Combine documents into a single string.
+
+        Args:
+            docs: List[Document], the documents to combine
+            **kwargs: Other parameters to use in combining documents, often
+                other inputs to the prompt.
+
+        Returns:
+            The first element returned is the single string output. The second
+            element returned is a dictionary of other keys to return.
+        """
 
     def _call(
         self,
         inputs: Dict[str, List[Document]],
         run_manager: Optional[CallbackManagerForChainRun] = None,
     ) -> Dict[str, str]:
+        """Prepare inputs, call combine docs, prepare outputs."""
         _run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager()
         docs = inputs[self.input_key]
         # Other keys are assumed to be needed for LLM prediction
@@ -92,6 +114,7 @@ class BaseCombineDocumentsChain(Chain, ABC):
         inputs: Dict[str, List[Document]],
         run_manager: Optional[AsyncCallbackManagerForChainRun] = None,
     ) -> Dict[str, str]:
+        """Prepare inputs, call combine docs, prepare outputs."""
         _run_manager = run_manager or AsyncCallbackManagerForChainRun.get_noop_manager()
         docs = inputs[self.input_key]
         # Other keys are assumed to be needed for LLM prediction
@@ -104,7 +127,12 @@ class BaseCombineDocumentsChain(Chain, ABC):
 
 
 class AnalyzeDocumentChain(Chain):
-    """Chain that splits documents, then analyzes it in pieces."""
+    """Chain that splits documents, then analyzes it in pieces.
+
+    This chain is parameterized by a TextSplitter and a CombineDocumentsChain.
+    This chain takes a single document as input, and then splits it up into chunks
+    and then passes those chunks to the CombineDocumentsChain.
+    """
 
     input_key: str = "input_document"  #: :meta private:
     text_splitter: TextSplitter = Field(default_factory=RecursiveCharacterTextSplitter)
@@ -131,6 +159,7 @@ class AnalyzeDocumentChain(Chain):
         inputs: Dict[str, str],
         run_manager: Optional[CallbackManagerForChainRun] = None,
     ) -> Dict[str, str]:
+        """Split document into chunks and pass to CombineDocumentsChain."""
        _run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager()
        document = inputs[self.input_key]
        docs = self.text_splitter.create_documents([document])
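Note on the `format_document` move above: the helper now lives in `langchain.schema` rather than `langchain.chains.combine_documents.base`, so downstream imports need updating. A minimal sketch of how it behaves (the document and prompt values here are made up for illustration):

    from langchain.docstore.document import Document
    from langchain.prompts import PromptTemplate
    from langchain.schema import format_document  # new location after this commit

    # A document prompt that uses `page_content` plus one metadata variable.
    doc_prompt = PromptTemplate(
        input_variables=["page_content", "source"],
        template="{source}: {page_content}",
    )
    doc = Document(page_content="Water is H2O.", metadata={"source": "notes.txt"})

    # Prints "notes.txt: Water is H2O."; a ValueError is raised if a prompt
    # variable (other than page_content) is missing from doc.metadata.
    print(format_document(doc, doc_prompt))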
langchain/chains/combine_documents/map_reduce.py

@@ -2,74 +2,97 @@
 
 from __future__ import annotations
 
-from typing import Any, Callable, Dict, List, Optional, Protocol, Tuple
+from typing import Any, Dict, List, Tuple
 
 from pydantic import Extra, root_validator
 
 from langchain.callbacks.manager import Callbacks
 from langchain.chains.combine_documents.base import BaseCombineDocumentsChain
+from langchain.chains.combine_documents.reduce import ReduceDocumentsChain
 from langchain.chains.llm import LLMChain
 from langchain.docstore.document import Document
 
 
-class CombineDocsProtocol(Protocol):
-    """Interface for the combine_docs method."""
-
-    def __call__(self, docs: List[Document], **kwargs: Any) -> str:
-        """Interface for the combine_docs method."""
-
-
-def _split_list_of_docs(
-    docs: List[Document], length_func: Callable, token_max: int, **kwargs: Any
-) -> List[List[Document]]:
-    new_result_doc_list = []
-    _sub_result_docs = []
-    for doc in docs:
-        _sub_result_docs.append(doc)
-        _num_tokens = length_func(_sub_result_docs, **kwargs)
-        if _num_tokens > token_max:
-            if len(_sub_result_docs) == 1:
-                raise ValueError(
-                    "A single document was longer than the context length,"
-                    " we cannot handle this."
-                )
-            if len(_sub_result_docs) == 2:
-                raise ValueError(
-                    "A single document was so long it could not be combined "
-                    "with another document, we cannot handle this."
-                )
-            new_result_doc_list.append(_sub_result_docs[:-1])
-            _sub_result_docs = _sub_result_docs[-1:]
-    new_result_doc_list.append(_sub_result_docs)
-    return new_result_doc_list
-
-
-def _collapse_docs(
-    docs: List[Document],
-    combine_document_func: CombineDocsProtocol,
-    **kwargs: Any,
-) -> Document:
-    result = combine_document_func(docs, **kwargs)
-    combined_metadata = {k: str(v) for k, v in docs[0].metadata.items()}
-    for doc in docs[1:]:
-        for k, v in doc.metadata.items():
-            if k in combined_metadata:
-                combined_metadata[k] += f", {v}"
-            else:
-                combined_metadata[k] = str(v)
-    return Document(page_content=result, metadata=combined_metadata)
-
-
 class MapReduceDocumentsChain(BaseCombineDocumentsChain):
-    """Combining documents by mapping a chain over them, then combining results."""
+    """Combining documents by mapping a chain over them, then combining results.
+
+    We first call `llm_chain` on each document individually, passing in the
+    `page_content` and any other kwargs. This is the `map` step.
+
+    We then process the results of that `map` step in a `reduce` step. This should
+    likely be a ReduceDocumentsChain.
+
+    Example:
+        .. code-block:: python
+
+            from langchain.chains import (
+                StuffDocumentsChain,
+                LLMChain,
+                ReduceDocumentsChain,
+                MapReduceDocumentsChain,
+            )
+            from langchain.prompts import PromptTemplate
+            from langchain.llms import OpenAI
+
+            # This controls how each document will be formatted. Specifically,
+            # it will be passed to `format_document` - see that function for more
+            # details.
+            document_prompt = PromptTemplate(
+                input_variables=["page_content"],
+                template="{page_content}"
+            )
+            document_variable_name = "context"
+            llm = OpenAI()
+            # The prompt here should take as an input variable the
+            # `document_variable_name`
+            prompt = PromptTemplate.from_template(
+                "Summarize this content: {context}"
+            )
+            llm_chain = LLMChain(llm=llm, prompt=prompt)
+            # We now define how to combine these summaries
+            reduce_prompt = PromptTemplate.from_template(
+                "Combine these summaries: {context}"
+            )
+            reduce_llm_chain = LLMChain(llm=llm, prompt=reduce_prompt)
+            combine_documents_chain = StuffDocumentsChain(
+                llm_chain=reduce_llm_chain,
+                document_prompt=document_prompt,
+                document_variable_name=document_variable_name
+            )
+            reduce_documents_chain = ReduceDocumentsChain(
+                combine_documents_chain=combine_documents_chain,
+            )
+            chain = MapReduceDocumentsChain(
+                llm_chain=llm_chain,
+                reduce_documents_chain=reduce_documents_chain,
+            )
+            # If we wanted to, we could also pass in collapse_documents_chain
+            # which is specifically aimed at collapsing documents BEFORE
+            # the final call.
+            prompt = PromptTemplate.from_template(
+                "Collapse this content: {context}"
+            )
+            llm_chain = LLMChain(llm=llm, prompt=prompt)
+            collapse_documents_chain = StuffDocumentsChain(
+                llm_chain=llm_chain,
+                document_prompt=document_prompt,
+                document_variable_name=document_variable_name
+            )
+            reduce_documents_chain = ReduceDocumentsChain(
+                combine_documents_chain=combine_documents_chain,
+                collapse_documents_chain=collapse_documents_chain,
+            )
+            chain = MapReduceDocumentsChain(
+                llm_chain=llm_chain,
+                reduce_documents_chain=reduce_documents_chain,
+            )
+    """
 
     llm_chain: LLMChain
     """Chain to apply to each document individually."""
-    combine_document_chain: BaseCombineDocumentsChain
-    """Chain to use to combine results of applying llm_chain to documents."""
-    collapse_document_chain: Optional[BaseCombineDocumentsChain] = None
-    """Chain to use to collapse intermediary results if needed.
-    If None, will use the combine_document_chain."""
+    reduce_documents_chain: BaseCombineDocumentsChain
+    """Chain to use to reduce the results of applying `llm_chain` to each doc.
+    This is typically either a ReduceDocumentsChain or a StuffDocumentsChain."""
     document_variable_name: str
     """The variable name in the llm_chain to put the documents in.
     If only one variable in the llm_chain, this need not be provided."""
@@ -93,6 +116,29 @@ class MapReduceDocumentsChain(BaseCombineDocumentsChain):
         extra = Extra.forbid
         arbitrary_types_allowed = True
 
+    @root_validator(pre=True)
+    def get_reduce_chain(cls, values: Dict) -> Dict:
+        """For backwards compatibility."""
+        if "combine_document_chain" in values:
+            if "reduce_documents_chain" in values:
+                raise ValueError(
+                    "Both `reduce_documents_chain` and `combine_document_chain` "
+                    "cannot be provided at the same time. `combine_document_chain` "
+                    "is deprecated, please only provide `reduce_documents_chain`"
+                )
+            combine_chain = values["combine_document_chain"]
+            collapse_chain = values.get("collapse_document_chain")
+            reduce_chain = ReduceDocumentsChain(
+                combine_documents_chain=combine_chain,
+                collapse_documents_chain=collapse_chain,
+            )
+            values["reduce_documents_chain"] = reduce_chain
+            del values["combine_document_chain"]
+            if "collapse_document_chain" in values:
+                del values["collapse_document_chain"]
+
+        return values
+
     @root_validator(pre=True)
     def get_return_intermediate_steps(cls, values: Dict) -> Dict:
         """For backwards compatibility."""
@@ -123,11 +169,31 @@ class MapReduceDocumentsChain(BaseCombineDocumentsChain):
         return values
 
     @property
-    def _collapse_chain(self) -> BaseCombineDocumentsChain:
-        if self.collapse_document_chain is not None:
-            return self.collapse_document_chain
+    def collapse_document_chain(self) -> BaseCombineDocumentsChain:
+        """Kept for backward compatibility."""
+        if isinstance(self.reduce_documents_chain, ReduceDocumentsChain):
+            if self.reduce_documents_chain.collapse_documents_chain:
+                return self.reduce_documents_chain.collapse_documents_chain
+            else:
+                return self.reduce_documents_chain.combine_documents_chain
         else:
-            return self.combine_document_chain
+            raise ValueError(
+                f"`reduce_documents_chain` is of type "
+                f"{type(self.reduce_documents_chain)} so it does not have "
+                f"this attribute."
+            )
+
+    @property
+    def combine_document_chain(self) -> BaseCombineDocumentsChain:
+        """Kept for backward compatibility."""
+        if isinstance(self.reduce_documents_chain, ReduceDocumentsChain):
+            return self.reduce_documents_chain.combine_documents_chain
+        else:
+            raise ValueError(
+                f"`reduce_documents_chain` is of type "
+                f"{type(self.reduce_documents_chain)} so it does not have "
+                f"this attribute."
+            )
 
     def combine_docs(
         self,
@@ -141,14 +207,24 @@ class MapReduceDocumentsChain(BaseCombineDocumentsChain):
         Combine by mapping first chain over all documents, then reducing the results.
         This reducing can be done recursively if needed (if there are many documents).
         """
-        results = self.llm_chain.apply(
+        map_results = self.llm_chain.apply(
             # FYI - this is parallelized and so it is fast.
             [{self.document_variable_name: d.page_content, **kwargs} for d in docs],
             callbacks=callbacks,
         )
-        return self._process_results(
-            results, docs, token_max, callbacks=callbacks, **kwargs
+        question_result_key = self.llm_chain.output_key
+        result_docs = [
+            Document(page_content=r[question_result_key], metadata=docs[i].metadata)
+            # This uses metadata from the docs, and the textual results from `results`
+            for i, r in enumerate(map_results)
+        ]
+        result, extra_return_dict = self.reduce_documents_chain.combine_docs(
+            result_docs, callbacks=callbacks, **kwargs
         )
+        if self.return_intermediate_steps:
+            intermediate_steps = [r[question_result_key] for r in map_results]
+            extra_return_dict["intermediate_steps"] = intermediate_steps
+        return result, extra_return_dict
 
     async def acombine_docs(
         self, docs: List[Document], callbacks: Callbacks = None, **kwargs: Any
@@ -158,83 +234,24 @@ class MapReduceDocumentsChain(BaseCombineDocumentsChain):
         Combine by mapping first chain over all documents, then reducing the results.
         This reducing can be done recursively if needed (if there are many documents).
         """
-        results = await self.llm_chain.aapply(
+        map_results = await self.llm_chain.aapply(
             # FYI - this is parallelized and so it is fast.
             [{**{self.document_variable_name: d.page_content}, **kwargs} for d in docs],
             callbacks=callbacks,
         )
-        return await self._aprocess_results(
-            results, docs, callbacks=callbacks, **kwargs
-        )
-
-    def _process_results_common(
-        self,
-        results: List[Dict],
-        docs: List[Document],
-        token_max: int = 3000,
-        callbacks: Callbacks = None,
-        **kwargs: Any,
-    ) -> Tuple[List[Document], dict]:
         question_result_key = self.llm_chain.output_key
         result_docs = [
             Document(page_content=r[question_result_key], metadata=docs[i].metadata)
             # This uses metadata from the docs, and the textual results from `results`
-            for i, r in enumerate(results)
+            for i, r in enumerate(map_results)
         ]
-        length_func = self.combine_document_chain.prompt_length
-        num_tokens = length_func(result_docs, **kwargs)
-
-        def _collapse_docs_func(docs: List[Document], **kwargs: Any) -> str:
-            return self._collapse_chain.run(
-                input_documents=docs, callbacks=callbacks, **kwargs
-            )
-
-        while num_tokens is not None and num_tokens > token_max:
-            new_result_doc_list = _split_list_of_docs(
-                result_docs, length_func, token_max, **kwargs
-            )
-            result_docs = []
-            for docs in new_result_doc_list:
-                new_doc = _collapse_docs(docs, _collapse_docs_func, **kwargs)
-                result_docs.append(new_doc)
-            num_tokens = length_func(result_docs, **kwargs)
+        result, extra_return_dict = await self.reduce_documents_chain.acombine_docs(
+            result_docs, callbacks=callbacks, **kwargs
+        )
         if self.return_intermediate_steps:
-            _results = [r[self.llm_chain.output_key] for r in results]
-            extra_return_dict = {"intermediate_steps": _results}
-        else:
-            extra_return_dict = {}
-        return result_docs, extra_return_dict
-
-    def _process_results(
-        self,
-        results: List[Dict],
-        docs: List[Document],
-        token_max: int = 3000,
-        callbacks: Callbacks = None,
-        **kwargs: Any,
-    ) -> Tuple[str, dict]:
-        result_docs, extra_return_dict = self._process_results_common(
-            results, docs, token_max, callbacks=callbacks, **kwargs
-        )
-        output = self.combine_document_chain.run(
-            input_documents=result_docs, callbacks=callbacks, **kwargs
-        )
-        return output, extra_return_dict
-
-    async def _aprocess_results(
-        self,
-        results: List[Dict],
-        docs: List[Document],
-        callbacks: Callbacks = None,
-        **kwargs: Any,
-    ) -> Tuple[str, dict]:
-        result_docs, extra_return_dict = self._process_results_common(
-            results, docs, callbacks=callbacks, **kwargs
-        )
-        output = await self.combine_document_chain.arun(
-            input_documents=result_docs, callbacks=callbacks, **kwargs
-        )
-        return output, extra_return_dict
+            intermediate_steps = [r[question_result_key] for r in map_results]
+            extra_return_dict["intermediate_steps"] = intermediate_steps
+        return result, extra_return_dict
 
     @property
     def _chain_type(self) -> str:
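Note on the `get_reduce_chain` validator above: it is what keeps old constructor calls working after the rename. A hedged sketch of the compatibility path (assumes `llm_chain` and `stuff_chain` were built as in the class docstring example):

    # Deprecated construction style, still accepted at this commit:
    chain = MapReduceDocumentsChain(
        llm_chain=llm_chain,
        combine_document_chain=stuff_chain,
    )
    # The pre=True root validator rewrites the kwargs before field validation:
    # stuff_chain gets wrapped in a ReduceDocumentsChain and stored under the
    # new `reduce_documents_chain` field.
    assert isinstance(chain.reduce_documents_chain, ReduceDocumentsChain)
    # The old names stay readable via the backward-compatibility properties:
    legacy = chain.combine_document_chain  # -> the wrapped combine chain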
langchain/chains/combine_documents/map_rerank.py

@@ -14,7 +14,48 @@ from langchain.output_parsers.regex import RegexParser
 
 
 class MapRerankDocumentsChain(BaseCombineDocumentsChain):
-    """Combining documents by mapping a chain over them, then reranking results."""
+    """Combining documents by mapping a chain over them, then reranking results.
+
+    This algorithm calls an LLMChain on each input document. The LLMChain is expected
+    to have an OutputParser that parses the result into both an answer (`answer_key`)
+    and a score (`rank_key`). The answer with the highest score is then returned.
+
+    Example:
+        .. code-block:: python
+
+            from langchain.chains import StuffDocumentsChain, LLMChain
+            from langchain.prompts import PromptTemplate
+            from langchain.llms import OpenAI
+            from langchain.output_parsers.regex import RegexParser
+
+            document_variable_name = "context"
+            llm = OpenAI()
+            # The prompt here should take as an input variable the
+            # `document_variable_name`
+            # The actual prompt will need to be a lot more complex, this is just
+            # an example.
+            prompt_template = (
+                "Use the following context to tell me the chemical formula "
+                "for water. Output both your answer and a score of how confident "
+                "you are. Context: {context}"
+            )
+            output_parser = RegexParser(
+                regex=r"(.*?)\nScore: (.*)",
+                output_keys=["answer", "score"],
+            )
+            prompt = PromptTemplate(
+                template=prompt_template,
+                input_variables=["context"],
+                output_parser=output_parser,
+            )
+            llm_chain = LLMChain(llm=llm, prompt=prompt)
+            chain = MapRerankDocumentsChain(
+                llm_chain=llm_chain,
+                document_variable_name=document_variable_name,
+                rank_key="score",
+                answer_key="answer",
+            )
+    """
 
     llm_chain: LLMChain
     """Chain to apply to each document individually."""
@@ -26,7 +67,10 @@ class MapRerankDocumentsChain(BaseCombineDocumentsChain):
     answer_key: str
     """Key in output of llm_chain to return as answer."""
     metadata_keys: Optional[List[str]] = None
+    """Additional metadata from the chosen document to return."""
     return_intermediate_steps: bool = False
+    """Return intermediate steps.
+    Intermediate steps include the results of calling llm_chain on each document."""
 
     class Config:
         """Configuration for this pydantic object."""
@@ -96,6 +140,16 @@ class MapRerankDocumentsChain(BaseCombineDocumentsChain):
         """Combine documents in a map rerank manner.
 
         Combine by mapping first chain over all documents, then reranking the results.
+
+        Args:
+            docs: List of documents to combine
+            callbacks: Callbacks to be passed through
+            **kwargs: additional parameters to be passed to LLM calls (like other
+                input variables besides the documents)
+
+        Returns:
+            The first element returned is the single string output. The second
+            element returned is a dictionary of other keys to return.
         """
         results = self.llm_chain.apply_and_parse(
             # FYI - this is parallelized and so it is fast.
@@ -110,6 +164,16 @@ class MapRerankDocumentsChain(BaseCombineDocumentsChain):
         """Combine documents in a map rerank manner.
 
         Combine by mapping first chain over all documents, then reranking the results.
+
+        Args:
+            docs: List of documents to combine
+            callbacks: Callbacks to be passed through
+            **kwargs: additional parameters to be passed to LLM calls (like other
+                input variables besides the documents)
+
+        Returns:
+            The first element returned is the single string output. The second
+            element returned is a dictionary of other keys to return.
         """
         results = await self.llm_chain.aapply_and_parse(
             # FYI - this is parallelized and so it is fast.
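Note on the rerank step: once the OutputParser has split each generation into `answer` and `score`, selection is just a sort on the rank key. A standalone sketch of that logic (the parsed dicts are invented; the chain itself sorts descending by the integer value of `rank_key`):

    parsed = [
        {"answer": "H2O", "score": "92"},
        {"answer": "HO2", "score": "11"},
        {"answer": "H2O", "score": "87"},
    ]
    # Pick the answer whose score ranks highest.
    best = sorted(parsed, key=lambda r: int(r["score"]), reverse=True)[0]
    print(best["answer"])  # -> "H2O"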
langchain/chains/combine_documents/reduce.py (new file, 277 lines)

@@ -0,0 +1,277 @@
+"""Combine many documents together by recursively reducing them."""
+
+from __future__ import annotations
+
+from typing import Any, Callable, List, Optional, Protocol, Tuple
+
+from pydantic import Extra
+
+from langchain.callbacks.manager import Callbacks
+from langchain.chains.combine_documents.base import BaseCombineDocumentsChain
+from langchain.docstore.document import Document
+
+
+class CombineDocsProtocol(Protocol):
+    """Interface for the combine_docs method."""
+
+    def __call__(self, docs: List[Document], **kwargs: Any) -> str:
+        """Interface for the combine_docs method."""
+
+
+class AsyncCombineDocsProtocol(Protocol):
+    """Interface for the combine_docs method."""
+
+    async def __call__(self, docs: List[Document], **kwargs: Any) -> str:
+        """Async interface for the combine_docs method."""
+
+
+def _split_list_of_docs(
+    docs: List[Document], length_func: Callable, token_max: int, **kwargs: Any
+) -> List[List[Document]]:
+    new_result_doc_list = []
+    _sub_result_docs = []
+    for doc in docs:
+        _sub_result_docs.append(doc)
+        _num_tokens = length_func(_sub_result_docs, **kwargs)
+        if _num_tokens > token_max:
+            if len(_sub_result_docs) == 1:
+                raise ValueError(
+                    "A single document was longer than the context length,"
+                    " we cannot handle this."
+                )
+            new_result_doc_list.append(_sub_result_docs[:-1])
+            _sub_result_docs = _sub_result_docs[-1:]
+    new_result_doc_list.append(_sub_result_docs)
+    return new_result_doc_list
+
+
+def _collapse_docs(
+    docs: List[Document],
+    combine_document_func: CombineDocsProtocol,
+    **kwargs: Any,
+) -> Document:
+    result = combine_document_func(docs, **kwargs)
+    combined_metadata = {k: str(v) for k, v in docs[0].metadata.items()}
+    for doc in docs[1:]:
+        for k, v in doc.metadata.items():
+            if k in combined_metadata:
+                combined_metadata[k] += f", {v}"
+            else:
+                combined_metadata[k] = str(v)
+    return Document(page_content=result, metadata=combined_metadata)
+
+
+async def _acollapse_docs(
+    docs: List[Document],
+    combine_document_func: AsyncCombineDocsProtocol,
+    **kwargs: Any,
+) -> Document:
+    result = await combine_document_func(docs, **kwargs)
+    combined_metadata = {k: str(v) for k, v in docs[0].metadata.items()}
+    for doc in docs[1:]:
+        for k, v in doc.metadata.items():
+            if k in combined_metadata:
+                combined_metadata[k] += f", {v}"
+            else:
+                combined_metadata[k] = str(v)
+    return Document(page_content=result, metadata=combined_metadata)
+
+
+class ReduceDocumentsChain(BaseCombineDocumentsChain):
+    """Combining documents by recursively reducing them.
+
+    This involves
+    - combine_documents_chain
+    - collapse_documents_chain
+
+    `combine_documents_chain` is ALWAYS provided. This is the final chain that is
+    called. We pass all previous results to this chain, and the output of this chain
+    is returned as a final result.
+
+    `collapse_documents_chain` is used if the documents passed in are too many to all
+    be passed to `combine_documents_chain` in one go. In this case,
+    `collapse_documents_chain` is called recursively on groups of documents as large
+    as are allowed.
+
+    Example:
+        .. code-block:: python
+
+            from langchain.chains import (
+                StuffDocumentsChain, LLMChain, ReduceDocumentsChain
+            )
+            from langchain.prompts import PromptTemplate
+            from langchain.llms import OpenAI
+
+            # This controls how each document will be formatted. Specifically,
+            # it will be passed to `format_document` - see that function for more
+            # details.
+            document_prompt = PromptTemplate(
+                input_variables=["page_content"],
+                template="{page_content}"
+            )
+            document_variable_name = "context"
+            llm = OpenAI()
+            # The prompt here should take as an input variable the
+            # `document_variable_name`
+            prompt = PromptTemplate.from_template(
+                "Summarize this content: {context}"
+            )
+            llm_chain = LLMChain(llm=llm, prompt=prompt)
+            combine_documents_chain = StuffDocumentsChain(
+                llm_chain=llm_chain,
+                document_prompt=document_prompt,
+                document_variable_name=document_variable_name
+            )
+            chain = ReduceDocumentsChain(
+                combine_documents_chain=combine_documents_chain,
+            )
+            # If we wanted to, we could also pass in collapse_documents_chain
+            # which is specifically aimed at collapsing documents BEFORE
+            # the final call.
+            prompt = PromptTemplate.from_template(
+                "Collapse this content: {context}"
+            )
+            llm_chain = LLMChain(llm=llm, prompt=prompt)
+            collapse_documents_chain = StuffDocumentsChain(
+                llm_chain=llm_chain,
+                document_prompt=document_prompt,
+                document_variable_name=document_variable_name
+            )
+            chain = ReduceDocumentsChain(
+                combine_documents_chain=combine_documents_chain,
+                collapse_documents_chain=collapse_documents_chain,
+            )
+    """
+
+    combine_documents_chain: BaseCombineDocumentsChain
+    """Final chain to call to combine documents.
+    This is typically a StuffDocumentsChain."""
+    collapse_documents_chain: Optional[BaseCombineDocumentsChain] = None
+    """Chain to use to collapse documents if needed until they can all fit.
+    If None, will use the combine_documents_chain.
+    This is typically a StuffDocumentsChain."""
+
+    class Config:
+        """Configuration for this pydantic object."""
+
+        extra = Extra.forbid
+        arbitrary_types_allowed = True
+
+    @property
+    def _collapse_chain(self) -> BaseCombineDocumentsChain:
+        if self.collapse_documents_chain is not None:
+            return self.collapse_documents_chain
+        else:
+            return self.combine_documents_chain
+
+    def combine_docs(
+        self,
+        docs: List[Document],
+        token_max: int = 3000,
+        callbacks: Callbacks = None,
+        **kwargs: Any,
+    ) -> Tuple[str, dict]:
+        """Combine multiple documents recursively.
+
+        Args:
+            docs: List of documents to combine, assumed that each one is less than
+                `token_max`.
+            token_max: Recursively creates groups of documents less than this number
+                of tokens.
+            callbacks: Callbacks to be passed through
+            **kwargs: additional parameters to be passed to LLM calls (like other
+                input variables besides the documents)
+
+        Returns:
+            The first element returned is the single string output. The second
+            element returned is a dictionary of other keys to return.
+        """
+        result_docs, extra_return_dict = self._collapse(
+            docs, token_max, callbacks=callbacks, **kwargs
+        )
+        return self.combine_documents_chain.combine_docs(
+            docs=result_docs, callbacks=callbacks, **kwargs
+        )
+
+    async def acombine_docs(
+        self, docs: List[Document], callbacks: Callbacks = None, **kwargs: Any
+    ) -> Tuple[str, dict]:
+        """Combine multiple documents recursively.
+
+        Args:
+            docs: List of documents to combine, assumed that each one is less than
+                `token_max`.
+            token_max: Recursively creates groups of documents less than this number
+                of tokens.
+            callbacks: Callbacks to be passed through
+            **kwargs: additional parameters to be passed to LLM calls (like other
+                input variables besides the documents)
+
+        Returns:
+            The first element returned is the single string output. The second
+            element returned is a dictionary of other keys to return.
+        """
+        result_docs, extra_return_dict = await self._acollapse(
+            docs, callbacks=callbacks, **kwargs
+        )
+        return await self.combine_documents_chain.acombine_docs(
+            docs=result_docs, callbacks=callbacks, **kwargs
+        )
+
+    def _collapse(
+        self,
+        docs: List[Document],
+        token_max: int = 3000,
+        callbacks: Callbacks = None,
+        **kwargs: Any,
+    ) -> Tuple[List[Document], dict]:
+        result_docs = docs
+        length_func = self.combine_documents_chain.prompt_length
+        num_tokens = length_func(result_docs, **kwargs)
+
+        def _collapse_docs_func(docs: List[Document], **kwargs: Any) -> str:
+            return self._collapse_chain.run(
+                input_documents=docs, callbacks=callbacks, **kwargs
+            )
+
+        while num_tokens is not None and num_tokens > token_max:
+            new_result_doc_list = _split_list_of_docs(
+                result_docs, length_func, token_max, **kwargs
+            )
+            result_docs = []
+            for docs in new_result_doc_list:
+                new_doc = _collapse_docs(docs, _collapse_docs_func, **kwargs)
+                result_docs.append(new_doc)
+            num_tokens = length_func(result_docs, **kwargs)
+        return result_docs, {}
+
+    async def _acollapse(
+        self,
+        docs: List[Document],
+        token_max: int = 3000,
+        callbacks: Callbacks = None,
+        **kwargs: Any,
+    ) -> Tuple[List[Document], dict]:
+        result_docs = docs
+        length_func = self.combine_documents_chain.prompt_length
+        num_tokens = length_func(result_docs, **kwargs)
+
+        async def _collapse_docs_func(docs: List[Document], **kwargs: Any) -> str:
+            return await self._collapse_chain.arun(
+                input_documents=docs, callbacks=callbacks, **kwargs
+            )
+
+        while num_tokens is not None and num_tokens > token_max:
+            new_result_doc_list = _split_list_of_docs(
+                result_docs, length_func, token_max, **kwargs
+            )
+            result_docs = []
+            for docs in new_result_doc_list:
+                new_doc = await _acollapse_docs(docs, _collapse_docs_func, **kwargs)
+                result_docs.append(new_doc)
+            num_tokens = length_func(result_docs, **kwargs)
+        return result_docs, {}
+
+    @property
+    def _chain_type(self) -> str:
+        return "reduce_documents_chain"
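Note on the collapse loop in `_collapse`: `_split_list_of_docs` greedily packs documents into groups that stay under `token_max` according to `length_func`, and each group is then collapsed into one new document. A toy demonstration of the packing step (the character-count length function is a stand-in for the real `combine_documents_chain.prompt_length`; `_split_list_of_docs` is an internal helper):

    from langchain.chains.combine_documents.reduce import _split_list_of_docs
    from langchain.docstore.document import Document

    docs = [Document(page_content="a" * 40) for _ in range(4)]

    def char_length(ds, **kwargs):
        # Toy length function: one "token" per character.
        return sum(len(d.page_content) for d in ds)

    # 40 + 40 fits under 100 but a third document overshoots, so the greedy
    # packer yields two groups of two documents each.
    groups = _split_list_of_docs(docs, char_length, 100)
    print([len(g) for g in groups])  # -> [2, 2]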
langchain/chains/combine_documents/refine.py

@@ -9,12 +9,11 @@ from pydantic import Extra, Field, root_validator
 from langchain.callbacks.manager import Callbacks
 from langchain.chains.combine_documents.base import (
     BaseCombineDocumentsChain,
-    format_document,
 )
 from langchain.chains.llm import LLMChain
 from langchain.docstore.document import Document
 from langchain.prompts.prompt import PromptTemplate
-from langchain.schema import BasePromptTemplate
+from langchain.schema import BasePromptTemplate, format_document
 
 
 def _get_default_document_prompt() -> PromptTemplate:
@@ -22,7 +21,55 @@ def _get_default_document_prompt() -> PromptTemplate:
 
 
 class RefineDocumentsChain(BaseCombineDocumentsChain):
-    """Combine documents by doing a first pass and then refining on more documents."""
+    """Combine documents by doing a first pass and then refining on more documents.
+
+    This algorithm first calls `initial_llm_chain` on the first document, passing
+    that first document in with the variable name `document_variable_name`, and
+    produces a new variable with the variable name `initial_response_name`.
+
+    Then, it loops over every remaining document. This is called the "refine" step.
+    It calls `refine_llm_chain`,
+    passing in that document with the variable name `document_variable_name`
+    as well as the previous response with the variable name `initial_response_name`.
+
+    Example:
+        .. code-block:: python
+
+            from langchain.chains import RefineDocumentsChain, LLMChain
+            from langchain.prompts import PromptTemplate
+            from langchain.llms import OpenAI
+
+            # This controls how each document will be formatted. Specifically,
+            # it will be passed to `format_document` - see that function for more
+            # details.
+            document_prompt = PromptTemplate(
+                input_variables=["page_content"],
+                template="{page_content}"
+            )
+            document_variable_name = "context"
+            llm = OpenAI()
+            # The prompt here should take as an input variable the
+            # `document_variable_name`
+            prompt = PromptTemplate.from_template(
+                "Summarize this content: {context}"
+            )
+            llm_chain = LLMChain(llm=llm, prompt=prompt)
+            initial_response_name = "prev_response"
+            # The prompt here should take as an input variable the
+            # `document_variable_name` as well as `initial_response_name`
+            prompt_refine = PromptTemplate.from_template(
+                "Here's your first summary: {prev_response}. "
+                "Now add to it based on the following context: {context}"
+            )
+            llm_chain_refine = LLMChain(llm=llm, prompt=prompt_refine)
+            chain = RefineDocumentsChain(
+                initial_llm_chain=llm_chain,
+                refine_llm_chain=llm_chain_refine,
+                document_prompt=document_prompt,
+                document_variable_name=document_variable_name,
+                initial_response_name=initial_response_name,
+            )
+    """
 
     initial_llm_chain: LLMChain
     """LLM chain to use on initial document."""
@@ -36,7 +83,7 @@ class RefineDocumentsChain(BaseCombineDocumentsChain):
     document_prompt: BasePromptTemplate = Field(
         default_factory=_get_default_document_prompt
     )
-    """Prompt to use to format each document."""
+    """Prompt to use to format each document, gets passed to `format_document`."""
     return_intermediate_steps: bool = False
     """Return the results of the refine steps in the output."""
 
@@ -89,7 +136,18 @@ class RefineDocumentsChain(BaseCombineDocumentsChain):
     def combine_docs(
         self, docs: List[Document], callbacks: Callbacks = None, **kwargs: Any
     ) -> Tuple[str, dict]:
-        """Combine by mapping first chain over all, then stuffing into final chain."""
+        """Combine by mapping first chain over all, then stuffing into final chain.
+
+        Args:
+            docs: List of documents to combine
+            callbacks: Callbacks to be passed through
+            **kwargs: additional parameters to be passed to LLM calls (like other
+                input variables besides the documents)
+
+        Returns:
+            The first element returned is the single string output. The second
+            element returned is a dictionary of other keys to return.
+        """
         inputs = self._construct_initial_inputs(docs, **kwargs)
         res = self.initial_llm_chain.predict(callbacks=callbacks, **inputs)
         refine_steps = [res]
@@ -103,7 +161,18 @@ class RefineDocumentsChain(BaseCombineDocumentsChain):
     async def acombine_docs(
         self, docs: List[Document], callbacks: Callbacks = None, **kwargs: Any
     ) -> Tuple[str, dict]:
-        """Combine by mapping first chain over all, then stuffing into final chain."""
+        """Combine by mapping first chain over all, then stuffing into final chain.
+
+        Args:
+            docs: List of documents to combine
+            callbacks: Callbacks to be passed through
+            **kwargs: additional parameters to be passed to LLM calls (like other
+                input variables besides the documents)
+
+        Returns:
+            The first element returned is the single string output. The second
+            element returned is a dictionary of other keys to return.
+        """
        inputs = self._construct_initial_inputs(docs, **kwargs)
        res = await self.initial_llm_chain.apredict(callbacks=callbacks, **inputs)
        refine_steps = [res]
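Note on the refine control flow the docstring describes: it is a fold over the documents, with each later call seeing the previous response. A schematic sketch, not the class's literal code (the lambda predictors stand in for `initial_llm_chain` / `refine_llm_chain`):

    def refine_over_docs(docs, initial_predict, refine_predict):
        # Initial pass: only the first document.
        response = initial_predict(context=docs[0])
        steps = [response]
        # Refine pass: fold every remaining document into the running answer.
        for doc in docs[1:]:
            response = refine_predict(context=doc, prev_response=response)
            steps.append(response)
        return response, steps

    out, steps = refine_over_docs(
        ["doc one", "doc two", "doc three"],
        initial_predict=lambda context: f"summary({context})",
        refine_predict=lambda context, prev_response: f"{prev_response} + {context}",
    )
    print(out)  # summary(doc one) + doc two + doc three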
@@ -7,12 +7,11 @@ from pydantic import Extra, Field, root_validator
|
|||||||
from langchain.callbacks.manager import Callbacks
|
from langchain.callbacks.manager import Callbacks
|
||||||
from langchain.chains.combine_documents.base import (
|
from langchain.chains.combine_documents.base import (
|
||||||
BaseCombineDocumentsChain,
|
BaseCombineDocumentsChain,
|
||||||
format_document,
|
|
||||||
)
|
)
|
||||||
from langchain.chains.llm import LLMChain
|
from langchain.chains.llm import LLMChain
|
||||||
from langchain.docstore.document import Document
|
from langchain.docstore.document import Document
|
||||||
from langchain.prompts.prompt import PromptTemplate
|
from langchain.prompts.prompt import PromptTemplate
|
||||||
from langchain.schema import BasePromptTemplate
|
from langchain.schema import BasePromptTemplate, format_document
|
||||||
|
|
||||||
|
|
||||||
def _get_default_document_prompt() -> PromptTemplate:
|
def _get_default_document_prompt() -> PromptTemplate:
|
||||||
@@ -20,14 +19,50 @@ def _get_default_document_prompt() -> PromptTemplate:
|
|||||||
|
|
||||||
|
|
||||||
class StuffDocumentsChain(BaseCombineDocumentsChain):
|
class StuffDocumentsChain(BaseCombineDocumentsChain):
|
||||||
"""Chain that combines documents by stuffing into context."""
|
"""Chain that combines documents by stuffing into context.
|
||||||
|
|
||||||
|
This chain takes a list of documents and first combines them into a single string.
|
||||||
|
It does this by formatting each document into a string with the `document_prompt`
|
||||||
|
and then joining them together with `document_separator`. It then adds that new
|
||||||
|
string to the inputs with the variable name set by `document_variable_name`.
|
||||||
|
Those inputs are then passed to the `llm_chain`.
|
||||||
|
|
||||||
|
Example:
|
||||||
|
.. code-block:: python
|
||||||
|
|
||||||
|
from langchain.chains import StuffDocumentsChain, LLMChain
|
||||||
|
from langchain.prompts import PromptTemplate
|
||||||
|
from langchain.llms import OpenAI
|
||||||
|
|
||||||
|
# This controls how each document will be formatted. Specifically,
|
||||||
|
# it will be passed to `format_document` - see that function for more
|
||||||
|
# details.
|
||||||
|
document_prompt = PromptTemplate(
|
||||||
|
input_variables=["page_content"],
|
||||||
|
template="{page_content}"
|
||||||
|
)
|
||||||
|
document_variable_name = "context"
|
||||||
|
llm = OpenAI()
|
||||||
|
# The prompt here should take as an input variable the
|
||||||
|
# `document_variable_name`
|
||||||
|
prompt = PromptTemplate.from_template(
|
||||||
|
"Summarize this content: {context}"
|
||||||
|
)
|
||||||
|
llm_chain = LLMChain(llm=llm, prompt=prompt)
|
||||||
|
chain = StuffDocumentsChain(
|
||||||
|
llm_chain=llm_chain,
|
||||||
|
document_prompt=document_prompt,
|
||||||
|
document_variable_name=document_variable_name
|
||||||
|
)
|
+    """
 
     llm_chain: LLMChain
-    """LLM wrapper to use after formatting documents."""
+    """LLM chain which is called with the formatted document string,
+    along with any other inputs."""
     document_prompt: BasePromptTemplate = Field(
         default_factory=_get_default_document_prompt
     )
-    """Prompt to use to format each document."""
+    """Prompt to use to format each document, gets passed to `format_document`."""
     document_variable_name: str
     """The variable name in the llm_chain to put the documents in.
 
     If only one variable in the llm_chain, this need not be provided."""
@@ -42,7 +77,12 @@ class StuffDocumentsChain(BaseCombineDocumentsChain):
     @root_validator(pre=True)
     def get_default_document_variable_name(cls, values: Dict) -> Dict:
-        """Get default document variable name, if not provided."""
+        """Get default document variable name, if not provided.
+
+        If only one variable is present in the llm_chain.prompt,
+        we can infer that the formatted documents should be passed in
+        with this variable name.
+        """
         llm_chain_variables = values["llm_chain"].prompt.input_variables
         if "document_variable_name" not in values:
             if len(llm_chain_variables) == 1:
@@ -61,6 +101,20 @@ class StuffDocumentsChain(BaseCombineDocumentsChain):
         return values
 
     def _get_inputs(self, docs: List[Document], **kwargs: Any) -> dict:
+        """Construct inputs from kwargs and docs.
+
+        Format and then join all the documents together into one input with name
+        `self.document_variable_name`. Then pluck any additional variables
+        from **kwargs.
+
+        Args:
+            docs: List of documents to format and then join into single input
+            **kwargs: additional inputs to chain, will pluck any other required
+                arguments from here.
+
+        Returns:
+            dictionary of inputs to LLMChain
+        """
         # Format each document according to the prompt
         doc_strings = [format_document(doc, self.document_prompt) for doc in docs]
         # Join the documents together to put them in the prompt.
@@ -73,7 +127,21 @@ class StuffDocumentsChain(BaseCombineDocumentsChain):
         return inputs
 
     def prompt_length(self, docs: List[Document], **kwargs: Any) -> Optional[int]:
-        """Get the prompt length by formatting the prompt."""
+        """Return the prompt length given the documents passed in.
+
+        This can be used by a caller to determine whether passing in a list
+        of documents would exceed a certain prompt length. This is useful when
+        trying to ensure that the size of a prompt remains below a certain
+        context limit.
+
+        Args:
+            docs: List[Document], a list of documents to use to calculate the
+                total prompt length.
+
+        Returns:
+            Returns None if the method does not depend on the prompt length,
+            otherwise the length of the prompt in tokens.
+        """
         inputs = self._get_inputs(docs, **kwargs)
         prompt = self.llm_chain.prompt.format(**inputs)
         return self.llm_chain.llm.get_num_tokens(prompt)
@@ -81,7 +149,17 @@ class StuffDocumentsChain(BaseCombineDocumentsChain):
     def combine_docs(
         self, docs: List[Document], callbacks: Callbacks = None, **kwargs: Any
     ) -> Tuple[str, dict]:
-        """Stuff all documents into one prompt and pass to LLM."""
+        """Stuff all documents into one prompt and pass to LLM.
+
+        Args:
+            docs: List of documents to join together into one variable
+            callbacks: Optional callbacks to pass along
+            **kwargs: additional parameters to use to get inputs to LLMChain.
+
+        Returns:
+            The first element returned is the single string output. The second
+            element returned is a dictionary of other keys to return.
+        """
         inputs = self._get_inputs(docs, **kwargs)
         # Call predict on the LLM.
         return self.llm_chain.predict(callbacks=callbacks, **inputs), {}
@@ -89,7 +167,17 @@ class StuffDocumentsChain(BaseCombineDocumentsChain):
     async def acombine_docs(
         self, docs: List[Document], callbacks: Callbacks = None, **kwargs: Any
     ) -> Tuple[str, dict]:
-        """Stuff all documents into one prompt and pass to LLM."""
+        """Stuff all documents into one prompt and pass to LLM.
+
+        Args:
+            docs: List of documents to join together into one variable
+            callbacks: Optional callbacks to pass along
+            **kwargs: additional parameters to use to get inputs to LLMChain.
+
+        Returns:
+            The first element returned is the single string output. The second
+            element returned is a dictionary of other keys to return.
+        """
         inputs = self._get_inputs(docs, **kwargs)
         # Call predict on the LLM.
         return await self.llm_chain.apredict(callbacks=callbacks, **inputs), {}
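Taken together, the expanded docstrings above describe the full `StuffDocumentsChain` flow: format each document with `document_prompt`, join the results into the single `document_variable_name` input, then call the LLM once. A minimal sketch of that flow, not part of this diff, using the `FakeListLLM` test model so it runs without an API key (prompt text and variable names are illustrative):

    from langchain.chains.combine_documents.stuff import StuffDocumentsChain
    from langchain.chains.llm import LLMChain
    from langchain.llms.fake import FakeListLLM
    from langchain.prompts import PromptTemplate
    from langchain.schema import Document

    # Canned-response model keeps the sketch self-contained.
    llm = FakeListLLM(responses=["A short summary."])
    prompt = PromptTemplate.from_template("Summarize the following:\n\n{context}")
    llm_chain = LLMChain(llm=llm, prompt=prompt)
    # document_variable_name is inferred as "context" because the prompt has
    # exactly one input variable (see get_default_document_variable_name above).
    chain = StuffDocumentsChain(llm_chain=llm_chain)

    docs = [Document(page_content="foo"), Document(page_content="bar")]
    output, extra = chain.combine_docs(docs)  # -> ("A short summary.", {})
    n_tokens = chain.prompt_length(docs)      # stuffed prompt size, in tokens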
@@ -5,6 +5,7 @@ from typing import Any, Union
 import yaml
 
+from langchain.chains import ReduceDocumentsChain
 from langchain.chains.api.base import APIChain
 from langchain.chains.base import Chain
 from langchain.chains.combine_documents.map_reduce import MapReduceDocumentsChain
@@ -117,9 +118,9 @@ def _load_map_reduce_documents_chain(
 
     if "combine_document_chain" in config:
         combine_document_chain_config = config.pop("combine_document_chain")
-        combine_document_chain = load_chain_from_config(combine_document_chain_config)
+        combine_documents_chain = load_chain_from_config(combine_document_chain_config)
     elif "combine_document_chain_path" in config:
-        combine_document_chain = load_chain(config.pop("combine_document_chain_path"))
+        combine_documents_chain = load_chain(config.pop("combine_document_chain_path"))
     else:
         raise ValueError(
             "One of `combine_document_chain` or "
@@ -128,17 +129,24 @@ def _load_map_reduce_documents_chain(
     if "collapse_document_chain" in config:
         collapse_document_chain_config = config.pop("collapse_document_chain")
         if collapse_document_chain_config is None:
-            collapse_document_chain = None
+            collapse_documents_chain = None
         else:
-            collapse_document_chain = load_chain_from_config(
+            collapse_documents_chain = load_chain_from_config(
                 collapse_document_chain_config
             )
     elif "collapse_document_chain_path" in config:
-        collapse_document_chain = load_chain(config.pop("collapse_document_chain_path"))
+        collapse_documents_chain = load_chain(
+            config.pop("collapse_document_chain_path")
+        )
+    else:
+        collapse_documents_chain = None
+    reduce_documents_chain = ReduceDocumentsChain(
+        combine_documents_chain=combine_documents_chain,
+        collapse_documents_chain=collapse_documents_chain,
+    )
     return MapReduceDocumentsChain(
         llm_chain=llm_chain,
-        combine_document_chain=combine_document_chain,
-        collapse_document_chain=collapse_document_chain,
+        reduce_documents_chain=reduce_documents_chain,
        **config,
    )
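In short, the loader still accepts the legacy `combine_document_chain` and `collapse_document_chain` config keys, but now folds both into a `ReduceDocumentsChain` rather than passing them to `MapReduceDocumentsChain` directly. A rough sketch of the object graph it now assembles (variable names as in the hunk above; not runnable on its own):

    # collapse_documents_chain may be None here; the new ReduceDocumentsChain
    # is presumably expected to fall back to combine_documents_chain in that case.
    reduce_documents_chain = ReduceDocumentsChain(
        combine_documents_chain=combine_documents_chain,
        collapse_documents_chain=collapse_documents_chain,
    )
    chain = MapReduceDocumentsChain(
        llm_chain=llm_chain,
        reduce_documents_chain=reduce_documents_chain,
    )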
@@ -11,6 +11,7 @@ from pydantic import Extra
 from langchain.base_language import BaseLanguageModel
 from langchain.callbacks.manager import CallbackManagerForChainRun, Callbacks
+from langchain.chains import ReduceDocumentsChain
 from langchain.chains.base import Chain
 from langchain.chains.combine_documents.base import BaseCombineDocumentsChain
 from langchain.chains.combine_documents.map_reduce import MapReduceDocumentsChain
@@ -44,14 +45,17 @@ class MapReduceChain(Chain):
     ) -> MapReduceChain:
         """Construct a map-reduce chain that uses the chain for map and reduce."""
         llm_chain = LLMChain(llm=llm, prompt=prompt, callbacks=callbacks)
-        reduce_chain = StuffDocumentsChain(
+        stuff_chain = StuffDocumentsChain(
             llm_chain=llm_chain,
             callbacks=callbacks,
             **(reduce_chain_kwargs if reduce_chain_kwargs else {}),
         )
+        reduce_documents_chain = ReduceDocumentsChain(
+            combine_documents_chain=stuff_chain
+        )
         combine_documents_chain = MapReduceDocumentsChain(
             llm_chain=llm_chain,
-            combine_document_chain=reduce_chain,
+            reduce_documents_chain=reduce_documents_chain,
             callbacks=callbacks,
             **(combine_chain_kwargs if combine_chain_kwargs else {}),
         )
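From the caller's perspective `MapReduceChain.from_params` is unchanged; only the internal wiring now routes the stuff chain through `ReduceDocumentsChain`. A hedged usage sketch (splitter and prompt choices are illustrative, not part of this diff):

    from langchain.chains.mapreduce import MapReduceChain
    from langchain.llms.fake import FakeListLLM
    from langchain.prompts import PromptTemplate
    from langchain.text_splitter import CharacterTextSplitter

    llm = FakeListLLM(responses=["A summary."])  # stand-in for a real model
    prompt = PromptTemplate.from_template("Summarize this content:\n\n{text}")
    chain = MapReduceChain.from_params(
        llm=llm, prompt=prompt, text_splitter=CharacterTextSplitter()
    )
    summary = chain.run("Some long text to split, map over, and reduce ...")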
@@ -14,6 +14,7 @@ from langchain.callbacks.manager import (
     AsyncCallbackManagerForChainRun,
     CallbackManagerForChainRun,
 )
+from langchain.chains import ReduceDocumentsChain
 from langchain.chains.base import Chain
 from langchain.chains.combine_documents.base import BaseCombineDocumentsChain
 from langchain.chains.combine_documents.map_reduce import MapReduceDocumentsChain
@@ -58,13 +59,16 @@ class BaseQAWithSourcesChain(Chain, ABC):
             document_prompt=document_prompt,
             document_variable_name="summaries",
         )
-        combine_document_chain = MapReduceDocumentsChain(
+        reduce_documents_chain = ReduceDocumentsChain(
+            combine_documents_chain=combine_results_chain
+        )
+        combine_documents_chain = MapReduceDocumentsChain(
             llm_chain=llm_question_chain,
-            combine_document_chain=combine_results_chain,
+            reduce_documents_chain=reduce_documents_chain,
             document_variable_name="context",
         )
         return cls(
-            combine_documents_chain=combine_document_chain,
+            combine_documents_chain=combine_documents_chain,
             **kwargs,
         )
@@ -78,10 +82,10 @@ class BaseQAWithSourcesChain(Chain, ABC):
     ) -> BaseQAWithSourcesChain:
         """Load chain from chain type."""
         _chain_kwargs = chain_type_kwargs or {}
-        combine_document_chain = load_qa_with_sources_chain(
+        combine_documents_chain = load_qa_with_sources_chain(
             llm, chain_type=chain_type, **_chain_kwargs
         )
-        return cls(combine_documents_chain=combine_document_chain, **kwargs)
+        return cls(combine_documents_chain=combine_documents_chain, **kwargs)
 
     class Config:
         """Configuration for this pydantic object."""
@@ -110,7 +114,7 @@ class BaseQAWithSourcesChain(Chain, ABC):
 
     @root_validator(pre=True)
     def validate_naming(cls, values: Dict) -> Dict:
-        """Fix backwards compatability in naming."""
+        """Fix backwards compatibility in naming."""
         if "combine_document_chain" in values:
             values["combine_documents_chain"] = values.pop("combine_document_chain")
         return values
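Note the back-compat shim: `validate_naming` runs with `pre=True`, so constructing a chain with the pre-refactor keyword still works. A hedged sketch of both entry points (assumes `llm` and `some_chain` are already built; `QAWithSourcesChain` is the concrete subclass typically used):

    from langchain.chains.qa_with_sources import QAWithSourcesChain

    # New-style: let the classmethod assemble the map-reduce pipeline.
    chain = QAWithSourcesChain.from_chain_type(llm, chain_type="map_reduce")

    # Legacy keyword is silently renamed by validate_naming:
    chain = QAWithSourcesChain(combine_document_chain=some_chain)
    assert chain.combine_documents_chain is some_chain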
@@ -4,6 +4,7 @@ from __future__ import annotations
 from typing import Any, Mapping, Optional, Protocol
 
 from langchain.base_language import BaseLanguageModel
+from langchain.chains import ReduceDocumentsChain
 from langchain.chains.combine_documents.base import BaseCombineDocumentsChain
 from langchain.chains.combine_documents.map_reduce import MapReduceDocumentsChain
 from langchain.chains.combine_documents.map_rerank import MapRerankDocumentsChain
@@ -83,7 +84,7 @@ def _load_map_reduce_chain(
     map_chain = LLMChain(llm=llm, prompt=question_prompt, verbose=verbose)
     _reduce_llm = reduce_llm or llm
     reduce_chain = LLMChain(llm=_reduce_llm, prompt=combine_prompt, verbose=verbose)
-    combine_document_chain = StuffDocumentsChain(
+    combine_documents_chain = StuffDocumentsChain(
         llm_chain=reduce_chain,
         document_variable_name=combine_document_variable_name,
         document_prompt=document_prompt,
@@ -107,11 +108,14 @@ def _load_map_reduce_chain(
             document_variable_name=combine_document_variable_name,
             document_prompt=document_prompt,
         )
+    reduce_documents_chain = ReduceDocumentsChain(
+        combine_documents_chain=combine_documents_chain,
+        collapse_documents_chain=collapse_chain,
+    )
     return MapReduceDocumentsChain(
         llm_chain=map_chain,
-        combine_document_chain=combine_document_chain,
+        reduce_documents_chain=reduce_documents_chain,
         document_variable_name=map_reduce_document_variable_name,
-        collapse_document_chain=collapse_chain,
         verbose=verbose,
         **kwargs,
     )
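Externally, `load_qa_with_sources_chain` behaves as before; a brief sketch exercising the refactored map-reduce path (the fake model, documents, and sources format are illustrative):

    from langchain.chains.qa_with_sources.loading import load_qa_with_sources_chain
    from langchain.llms.fake import FakeListLLM
    from langchain.schema import Document

    llm = FakeListLLM(responses=["foo is a bar\nSOURCES: doc-1"])
    chain = load_qa_with_sources_chain(llm, chain_type="map_reduce")
    docs = [Document(page_content="foo is a bar", metadata={"source": "doc-1"})]
    result = chain({"input_documents": docs, "question": "What is foo?"})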
@@ -4,6 +4,7 @@ from typing import Any, Mapping, Optional, Protocol
 from langchain.base_language import BaseLanguageModel
 from langchain.callbacks.base import BaseCallbackManager
 from langchain.callbacks.manager import Callbacks
+from langchain.chains import ReduceDocumentsChain
 from langchain.chains.combine_documents.base import BaseCombineDocumentsChain
 from langchain.chains.combine_documents.map_reduce import MapReduceDocumentsChain
 from langchain.chains.combine_documents.map_rerank import MapRerankDocumentsChain
@@ -122,7 +123,7 @@ def _load_map_reduce_chain(
         callbacks=callbacks,
     )
     # TODO: document prompt
-    combine_document_chain = StuffDocumentsChain(
+    combine_documents_chain = StuffDocumentsChain(
         llm_chain=reduce_chain,
         document_variable_name=combine_document_variable_name,
         verbose=verbose,
@@ -150,11 +151,14 @@ def _load_map_reduce_chain(
             verbose=verbose,
             callback_manager=callback_manager,
         )
+    reduce_documents_chain = ReduceDocumentsChain(
+        combine_documents_chain=combine_documents_chain,
+        collapse_documents_chain=collapse_chain,
+    )
     return MapReduceDocumentsChain(
         llm_chain=map_chain,
-        combine_document_chain=combine_document_chain,
         document_variable_name=map_reduce_document_variable_name,
-        collapse_document_chain=collapse_chain,
+        reduce_documents_chain=reduce_documents_chain,
         verbose=verbose,
         callback_manager=callback_manager,
         callbacks=callbacks,
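The public entry point here, `load_qa_chain`, is unchanged for callers; a quick sketch against the refactored `map_reduce` type (fake model and documents illustrative):

    from langchain.chains.question_answering import load_qa_chain
    from langchain.llms.fake import FakeListLLM
    from langchain.schema import Document

    llm = FakeListLLM(responses=["foo is a bar"])  # stand-in for a real model
    chain = load_qa_chain(llm, chain_type="map_reduce")
    docs = [Document(page_content="foo is a bar")]
    result = chain({"input_documents": docs, "question": "What is foo?"})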
@@ -2,6 +2,7 @@
 from typing import Any, Mapping, Optional, Protocol
 
 from langchain.base_language import BaseLanguageModel
+from langchain.chains import ReduceDocumentsChain
 from langchain.chains.combine_documents.base import BaseCombineDocumentsChain
 from langchain.chains.combine_documents.map_reduce import MapReduceDocumentsChain
 from langchain.chains.combine_documents.refine import RefineDocumentsChain
@@ -53,7 +54,7 @@ def _load_map_reduce_chain(
     _reduce_llm = reduce_llm or llm
     reduce_chain = LLMChain(llm=_reduce_llm, prompt=combine_prompt, verbose=verbose)
     # TODO: document prompt
-    combine_document_chain = StuffDocumentsChain(
+    combine_documents_chain = StuffDocumentsChain(
         llm_chain=reduce_chain,
         document_variable_name=combine_document_variable_name,
         verbose=verbose,
@@ -75,11 +76,14 @@ def _load_map_reduce_chain(
             ),
             document_variable_name=combine_document_variable_name,
         )
+    reduce_documents_chain = ReduceDocumentsChain(
+        combine_documents_chain=combine_documents_chain,
+        collapse_documents_chain=collapse_chain,
+    )
     return MapReduceDocumentsChain(
         llm_chain=map_chain,
-        combine_document_chain=combine_document_chain,
+        reduce_documents_chain=reduce_documents_chain,
         document_variable_name=map_reduce_document_variable_name,
-        collapse_document_chain=collapse_chain,
         verbose=verbose,
         **kwargs,
     )
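Likewise for summarization, `load_summarize_chain` keeps its signature; a short sketch (fake model illustrative):

    from langchain.chains.summarize import load_summarize_chain
    from langchain.llms.fake import FakeListLLM
    from langchain.schema import Document

    llm = FakeListLLM(responses=["A summary."])  # stand-in for a real model
    chain = load_summarize_chain(llm, chain_type="map_reduce")
    summary = chain.run([Document(page_content="Some long text.")])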
@@ -28,7 +28,7 @@ from langchain.schema.output_parser import (
     OutputParserException,
 )
 from langchain.schema.prompt import PromptValue
-from langchain.schema.prompt_template import BasePromptTemplate
+from langchain.schema.prompt_template import BasePromptTemplate, format_document
 from langchain.schema.retriever import BaseRetriever
 
 RUN_KEY = "__run"
@@ -66,4 +66,5 @@ __all__ = [
     "BaseOutputParser",
     "BaseLLMOutputParser",
     "BasePromptTemplate",
+    "format_document",
 ]
@@ -9,7 +9,9 @@ import yaml
 from pydantic import Field, root_validator
 
 from langchain.load.serializable import Serializable
-from langchain.schema import BaseOutputParser, PromptValue
+from langchain.schema.document import Document
+from langchain.schema.output_parser import BaseOutputParser
+from langchain.schema.prompt import PromptValue
 
 
 class BasePromptTemplate(Serializable, ABC):
@@ -137,3 +139,48 @@ class BasePromptTemplate(Serializable, ABC):
             yaml.dump(prompt_dict, f, default_flow_style=False)
         else:
             raise ValueError(f"{save_path} must be json or yaml")
+
+
+def format_document(doc: Document, prompt: BasePromptTemplate) -> str:
+    """Format a document into a string based on a prompt template.
+
+    First, this pulls information from the document from two sources:
+
+    1. `page_content`: this takes the information from the `document.page_content`
+        and assigns it to a variable named `page_content`.
+    2. metadata: this takes information from `document.metadata` and assigns
+        it to variables of the same name.
+
+    Those variables are then passed into the `prompt` to produce a formatted string.
+
+    Args:
+        doc: Document, the page_content and metadata will be used to create
+            the final string.
+        prompt: BasePromptTemplate, will be used to format the page_content
+            and metadata into the final string.
+
+    Returns:
+        The formatted document as a string.
+
+    Example:
+        .. code-block:: python
+
+            from langchain.schema import Document
+            from langchain.prompts import PromptTemplate
+
+            doc = Document(page_content="This is a joke", metadata={"page": "1"})
+            prompt = PromptTemplate.from_template("Page {page}: {page_content}")
+            format_document(doc, prompt)
+            >>> "Page 1: This is a joke"
+    """
+    base_info = {"page_content": doc.page_content, **doc.metadata}
+    missing_metadata = set(prompt.input_variables).difference(base_info)
+    if len(missing_metadata) > 0:
+        required_metadata = [
+            iv for iv in prompt.input_variables if iv != "page_content"
+        ]
+        raise ValueError(
+            f"Document prompt requires documents to have metadata variables: "
+            f"{required_metadata}. Received document with missing metadata: "
+            f"{list(missing_metadata)}."
+        )
+    document_info = {k: base_info[k] for k in prompt.input_variables}
+    return prompt.format(**document_info)
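One behavior worth calling out in the new `format_document`: every prompt input variable other than `page_content` must exist in the document's metadata, otherwise a `ValueError` is raised. A small sketch of that failure mode (it mirrors the docstring example above, minus the metadata):

    from langchain.prompts import PromptTemplate
    from langchain.schema import Document, format_document

    prompt = PromptTemplate.from_template("Page {page}: {page_content}")
    doc = Document(page_content="This is a joke")  # note: no "page" metadata
    try:
        format_document(doc, prompt)
    except ValueError as err:
        print(err)  # lists the required metadata variables, here ['page']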
@@ -5,12 +5,12 @@ from typing import Any, List
 import pytest
 
 from langchain import PromptTemplate
-from langchain.chains.combine_documents.base import format_document
-from langchain.chains.combine_documents.map_reduce import (
+from langchain.chains.combine_documents.reduce import (
     _collapse_docs,
     _split_list_of_docs,
 )
 from langchain.docstore.document import Document
+from langchain.schema import format_document
 
 
 def _fake_docs_len_func(docs: List[Document]) -> int:
@@ -28,13 +28,6 @@ def test__split_list_long_single_doc() -> None:
         _split_list_of_docs(docs, _fake_docs_len_func, 100)
 
 
-def test__split_list_long_pair_doc() -> None:
-    """Test splitting of a list with two medium docs."""
-    docs = [Document(page_content="foo" * 30)] * 2
-    with pytest.raises(ValueError):
-        _split_list_of_docs(docs, _fake_docs_len_func, 100)
-
-
 def test__split_list_single_doc() -> None:
     """Test splitting works with just a single doc."""
     docs = [Document(page_content="foo")]
@@ -86,8 +86,8 @@ def test_imports() -> None:
     from langchain.document_loaders import BSHTMLLoader  # noqa: F401
     from langchain.embeddings import OpenAIEmbeddings  # noqa: F401
     from langchain.llms import OpenAI  # noqa: F401
-    from langchain.prompts import BasePromptTemplate  # noqa: F401
     from langchain.retrievers import VespaRetriever  # noqa: F401
+    from langchain.schema import BasePromptTemplate  # noqa: F401
     from langchain.tools import DuckDuckGoSearchResults  # noqa: F401
     from langchain.utilities import SerpAPIWrapper  # noqa: F401
     from langchain.vectorstores import FAISS  # noqa: F401