Docs combine document chain (#6994)

Co-authored-by: Dev 2049 <dev.dev2049@gmail.com>
Co-authored-by: Bagatur <baskaryan@gmail.com>
Harrison Chase
2023-07-04 11:51:04 -07:00
committed by GitHub
parent 81eebc4070
commit 0ad984fa27
17 changed files with 820 additions and 205 deletions

View File

@@ -4,6 +4,7 @@ from langchain.chains.api.openapi.chain import OpenAPIEndpointChain
 from langchain.chains.combine_documents.base import AnalyzeDocumentChain
 from langchain.chains.combine_documents.map_reduce import MapReduceDocumentsChain
 from langchain.chains.combine_documents.map_rerank import MapRerankDocumentsChain
+from langchain.chains.combine_documents.reduce import ReduceDocumentsChain
 from langchain.chains.combine_documents.refine import RefineDocumentsChain
 from langchain.chains.combine_documents.stuff import StuffDocumentsChain
 from langchain.chains.constitutional_ai.base import ConstitutionalChain
@@ -111,4 +112,5 @@ __all__ = [
"MapRerankDocumentsChain", "MapRerankDocumentsChain",
"MapReduceDocumentsChain", "MapReduceDocumentsChain",
"RefineDocumentsChain", "RefineDocumentsChain",
"ReduceDocumentsChain",
] ]
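
The new chain is also re-exported at the package root. A quick sanity check (a minimal sketch, assuming the layout in this diff):

# Both import paths should now resolve to the same class.
from langchain.chains import ReduceDocumentsChain
from langchain.chains.combine_documents.reduce import ReduceDocumentsChain as ReduceDirect

assert ReduceDocumentsChain is ReduceDirect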

View File

@@ -11,30 +11,20 @@ from langchain.callbacks.manager import (
 )
 from langchain.chains.base import Chain
 from langchain.docstore.document import Document
-from langchain.schema import BasePromptTemplate
 from langchain.text_splitter import RecursiveCharacterTextSplitter, TextSplitter


-def format_document(doc: Document, prompt: BasePromptTemplate) -> str:
-    """Format a document into a string based on a prompt template."""
-    base_info = {"page_content": doc.page_content}
-    base_info.update(doc.metadata)
-    missing_metadata = set(prompt.input_variables).difference(base_info)
-    if len(missing_metadata) > 0:
-        required_metadata = [
-            iv for iv in prompt.input_variables if iv != "page_content"
-        ]
-        raise ValueError(
-            f"Document prompt requires documents to have metadata variables: "
-            f"{required_metadata}. Received document with missing metadata: "
-            f"{list(missing_metadata)}."
-        )
-    document_info = {k: base_info[k] for k in prompt.input_variables}
-    return prompt.format(**document_info)


 class BaseCombineDocumentsChain(Chain, ABC):
-    """Base interface for chains combining documents."""
+    """Base interface for chains combining documents.
+
+    Subclasses of this chain deal with combining documents in a variety of
+    ways. This base class exists to add some uniformity in the interface these types
+    of chains should expose. Namely, they expect an input key related to the documents
+    to use (default `input_documents`), and then also expose a method to calculate
+    the length of a prompt from documents (useful for outside callers to use to
+    determine whether it's safe to pass a list of documents into this chain or whether
+    that will be longer than the context length).
+    """

     input_key: str = "input_documents"  #: :meta private:
     output_key: str = "output_text"  #: :meta private:
@@ -58,25 +48,57 @@ class BaseCombineDocumentsChain(Chain, ABC):
     def prompt_length(self, docs: List[Document], **kwargs: Any) -> Optional[int]:
         """Return the prompt length given the documents passed in.

-        Returns None if the method does not depend on the prompt length.
+        This can be used by a caller to determine whether passing in a list
+        of documents would exceed a certain prompt length. This is useful when
+        trying to ensure that the size of a prompt remains below a certain
+        context limit.
+
+        Args:
+            docs: List[Document], a list of documents to use to calculate the
+                total prompt length.
+
+        Returns:
+            Returns None if the method does not depend on the prompt length,
+            otherwise the length of the prompt in tokens.
         """
         return None

     @abstractmethod
     def combine_docs(self, docs: List[Document], **kwargs: Any) -> Tuple[str, dict]:
-        """Combine documents into a single string."""
+        """Combine documents into a single string.
+
+        Args:
+            docs: List[Document], the documents to combine
+            **kwargs: Other parameters to use in combining documents, often
+                other inputs to the prompt.
+
+        Returns:
+            The first element returned is the single string output. The second
+            element returned is a dictionary of other keys to return.
+        """

     @abstractmethod
     async def acombine_docs(
         self, docs: List[Document], **kwargs: Any
     ) -> Tuple[str, dict]:
-        """Combine documents into a single string asynchronously."""
+        """Combine documents into a single string asynchronously.
+
+        Args:
+            docs: List[Document], the documents to combine
+            **kwargs: Other parameters to use in combining documents, often
+                other inputs to the prompt.
+
+        Returns:
+            The first element returned is the single string output. The second
+            element returned is a dictionary of other keys to return.
+        """

     def _call(
         self,
         inputs: Dict[str, List[Document]],
         run_manager: Optional[CallbackManagerForChainRun] = None,
     ) -> Dict[str, str]:
+        """Prepare inputs, call combine docs, prepare outputs."""
         _run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager()
         docs = inputs[self.input_key]
         # Other keys are assumed to be needed for LLM prediction
@@ -92,6 +114,7 @@ class BaseCombineDocumentsChain(Chain, ABC):
         inputs: Dict[str, List[Document]],
         run_manager: Optional[AsyncCallbackManagerForChainRun] = None,
     ) -> Dict[str, str]:
+        """Prepare inputs, call combine docs, prepare outputs."""
         _run_manager = run_manager or AsyncCallbackManagerForChainRun.get_noop_manager()
         docs = inputs[self.input_key]
         # Other keys are assumed to be needed for LLM prediction
@@ -104,7 +127,12 @@ class BaseCombineDocumentsChain(Chain, ABC):
 class AnalyzeDocumentChain(Chain):
-    """Chain that splits documents, then analyzes it in pieces."""
+    """Chain that splits a document, then analyzes it in pieces.
+
+    This chain is parameterized by a TextSplitter and a CombineDocumentsChain.
+    This chain takes a single document as input, splits it up into chunks,
+    and then passes those chunks to the CombineDocumentsChain.
+    """

     input_key: str = "input_document"  #: :meta private:
     text_splitter: TextSplitter = Field(default_factory=RecursiveCharacterTextSplitter)
@@ -131,6 +159,7 @@ class AnalyzeDocumentChain(Chain):
         inputs: Dict[str, str],
         run_manager: Optional[CallbackManagerForChainRun] = None,
     ) -> Dict[str, str]:
+        """Split document into chunks and pass to CombineDocumentsChain."""
         _run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager()
         document = inputs[self.input_key]
         docs = self.text_splitter.create_documents([document])
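
A minimal end-to-end sketch of the flow the new docstring describes (OpenAI is a placeholder LLM; `combine_docs_chain` is the chain's existing field, not shown in this hunk):

from langchain.chains import AnalyzeDocumentChain, LLMChain, StuffDocumentsChain
from langchain.llms import OpenAI
from langchain.prompts import PromptTemplate

llm_chain = LLMChain(
    llm=OpenAI(),
    prompt=PromptTemplate.from_template("Summarize this content: {context}"),
)
combine_documents_chain = StuffDocumentsChain(
    llm_chain=llm_chain,
    document_variable_name="context",
)
chain = AnalyzeDocumentChain(combine_docs_chain=combine_documents_chain)
# The single input document is split by the default
# RecursiveCharacterTextSplitter, then the chunks are combined.
summary = chain.run(input_document="...a very long document...")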

View File

@@ -2,74 +2,97 @@
 from __future__ import annotations

-from typing import Any, Callable, Dict, List, Optional, Protocol, Tuple
+from typing import Any, Dict, List, Tuple

 from pydantic import Extra, root_validator

 from langchain.callbacks.manager import Callbacks
 from langchain.chains.combine_documents.base import BaseCombineDocumentsChain
+from langchain.chains.combine_documents.reduce import ReduceDocumentsChain
 from langchain.chains.llm import LLMChain
 from langchain.docstore.document import Document


-class CombineDocsProtocol(Protocol):
-    """Interface for the combine_docs method."""
-
-    def __call__(self, docs: List[Document], **kwargs: Any) -> str:
-        """Interface for the combine_docs method."""
-
-
-def _split_list_of_docs(
-    docs: List[Document], length_func: Callable, token_max: int, **kwargs: Any
-) -> List[List[Document]]:
-    new_result_doc_list = []
-    _sub_result_docs = []
-    for doc in docs:
-        _sub_result_docs.append(doc)
-        _num_tokens = length_func(_sub_result_docs, **kwargs)
-        if _num_tokens > token_max:
-            if len(_sub_result_docs) == 1:
-                raise ValueError(
-                    "A single document was longer than the context length,"
-                    " we cannot handle this."
-                )
-            if len(_sub_result_docs) == 2:
-                raise ValueError(
-                    "A single document was so long it could not be combined "
-                    "with another document, we cannot handle this."
-                )
-            new_result_doc_list.append(_sub_result_docs[:-1])
-            _sub_result_docs = _sub_result_docs[-1:]
-    new_result_doc_list.append(_sub_result_docs)
-    return new_result_doc_list
-
-
-def _collapse_docs(
-    docs: List[Document],
-    combine_document_func: CombineDocsProtocol,
-    **kwargs: Any,
-) -> Document:
-    result = combine_document_func(docs, **kwargs)
-    combined_metadata = {k: str(v) for k, v in docs[0].metadata.items()}
-    for doc in docs[1:]:
-        for k, v in doc.metadata.items():
-            if k in combined_metadata:
-                combined_metadata[k] += f", {v}"
-            else:
-                combined_metadata[k] = str(v)
-    return Document(page_content=result, metadata=combined_metadata)
-
-
 class MapReduceDocumentsChain(BaseCombineDocumentsChain):
-    """Combining documents by mapping a chain over them, then combining results."""
+    """Combining documents by mapping a chain over them, then combining results.
+
+    We first call `llm_chain` on each document individually, passing in the
+    `page_content` and any other kwargs. This is the `map` step.
+
+    We then process the results of that `map` step in a `reduce` step. This should
+    likely be a ReduceDocumentsChain.
+
+    Example:
+        .. code-block:: python
+
+            from langchain.chains import (
+                StuffDocumentsChain,
+                LLMChain,
+                ReduceDocumentsChain,
+                MapReduceDocumentsChain,
+            )
+            from langchain.prompts import PromptTemplate
+            from langchain.llms import OpenAI
+
+            # This controls how each document will be formatted. Specifically,
+            # it will be passed to `format_document` - see that function for more
+            # details.
+            document_prompt = PromptTemplate(
+                input_variables=["page_content"],
+                template="{page_content}"
+            )
+            document_variable_name = "context"
+            llm = OpenAI()
+            # The prompt here should take as an input variable the
+            # `document_variable_name`
+            prompt = PromptTemplate.from_template(
+                "Summarize this content: {context}"
+            )
+            llm_chain = LLMChain(llm=llm, prompt=prompt)
+            # We now define how to combine these summaries
+            reduce_prompt = PromptTemplate.from_template(
+                "Combine these summaries: {context}"
+            )
+            reduce_llm_chain = LLMChain(llm=llm, prompt=reduce_prompt)
+            combine_documents_chain = StuffDocumentsChain(
+                llm_chain=reduce_llm_chain,
+                document_prompt=document_prompt,
+                document_variable_name=document_variable_name
+            )
+            reduce_documents_chain = ReduceDocumentsChain(
+                combine_documents_chain=combine_documents_chain,
+            )
+            chain = MapReduceDocumentsChain(
+                llm_chain=llm_chain,
+                reduce_documents_chain=reduce_documents_chain,
+            )
+
+            # If we wanted to, we could also pass in collapse_documents_chain
+            # which is specifically aimed at collapsing documents BEFORE
+            # the final call.
+            prompt = PromptTemplate.from_template(
+                "Collapse this content: {context}"
+            )
+            llm_chain = LLMChain(llm=llm, prompt=prompt)
+            collapse_documents_chain = StuffDocumentsChain(
+                llm_chain=llm_chain,
+                document_prompt=document_prompt,
+                document_variable_name=document_variable_name
+            )
+            reduce_documents_chain = ReduceDocumentsChain(
+                combine_documents_chain=combine_documents_chain,
+                collapse_documents_chain=collapse_documents_chain,
+            )
+            chain = MapReduceDocumentsChain(
+                llm_chain=llm_chain,
+                reduce_documents_chain=reduce_documents_chain,
+            )
+    """

     llm_chain: LLMChain
     """Chain to apply to each document individually."""
-    combine_document_chain: BaseCombineDocumentsChain
-    """Chain to use to combine results of applying llm_chain to documents."""
-    collapse_document_chain: Optional[BaseCombineDocumentsChain] = None
-    """Chain to use to collapse intermediary results if needed.
-    If None, will use the combine_document_chain."""
+    reduce_documents_chain: BaseCombineDocumentsChain
+    """Chain to use to reduce the results of applying `llm_chain` to each doc.
+    This is typically either a ReduceDocumentsChain or StuffDocumentsChain."""
     document_variable_name: str
     """The variable name in the llm_chain to put the documents in.
     If only one variable in the llm_chain, this need not be provided."""
@@ -93,6 +116,29 @@ class MapReduceDocumentsChain(BaseCombineDocumentsChain):
         extra = Extra.forbid
         arbitrary_types_allowed = True

+    @root_validator(pre=True)
+    def get_reduce_chain(cls, values: Dict) -> Dict:
+        """For backwards compatibility."""
+        if "combine_document_chain" in values:
+            if "reduce_documents_chain" in values:
+                raise ValueError(
+                    "Both `reduce_documents_chain` and `combine_document_chain` "
+                    "cannot be provided at the same time. `combine_document_chain` "
+                    "is deprecated, please only provide `reduce_documents_chain`"
+                )
+            combine_chain = values["combine_document_chain"]
+            collapse_chain = values.get("collapse_document_chain")
+            reduce_chain = ReduceDocumentsChain(
+                combine_documents_chain=combine_chain,
+                collapse_documents_chain=collapse_chain,
+            )
+            values["reduce_documents_chain"] = reduce_chain
+            del values["combine_document_chain"]
+            if "collapse_document_chain" in values:
+                del values["collapse_document_chain"]
+        return values
+
     @root_validator(pre=True)
     def get_return_intermediate_steps(cls, values: Dict) -> Dict:
         """For backwards compatibility."""
@@ -123,11 +169,31 @@ class MapReduceDocumentsChain(BaseCombineDocumentsChain):
         return values

     @property
-    def _collapse_chain(self) -> BaseCombineDocumentsChain:
-        if self.collapse_document_chain is not None:
-            return self.collapse_document_chain
-        else:
-            return self.combine_document_chain
+    def collapse_document_chain(self) -> BaseCombineDocumentsChain:
+        """Kept for backward compatibility."""
+        if isinstance(self.reduce_documents_chain, ReduceDocumentsChain):
+            if self.reduce_documents_chain.collapse_documents_chain:
+                return self.reduce_documents_chain.collapse_documents_chain
+            else:
+                return self.reduce_documents_chain.combine_documents_chain
+        else:
+            raise ValueError(
+                f"`reduce_documents_chain` is of type "
+                f"{type(self.reduce_documents_chain)} so it does not have "
+                f"this attribute."
+            )
+
+    @property
+    def combine_document_chain(self) -> BaseCombineDocumentsChain:
+        """Kept for backward compatibility."""
+        if isinstance(self.reduce_documents_chain, ReduceDocumentsChain):
+            return self.reduce_documents_chain.combine_documents_chain
+        else:
+            raise ValueError(
+                f"`reduce_documents_chain` is of type "
+                f"{type(self.reduce_documents_chain)} so it does not have "
+                f"this attribute."
+            )

     def combine_docs(
         self,
@@ -141,14 +207,24 @@ class MapReduceDocumentsChain(BaseCombineDocumentsChain):
         Combine by mapping first chain over all documents, then reducing the results.
         This reducing can be done recursively if needed (if there are many documents).
         """
-        results = self.llm_chain.apply(
+        map_results = self.llm_chain.apply(
             # FYI - this is parallelized and so it is fast.
             [{self.document_variable_name: d.page_content, **kwargs} for d in docs],
             callbacks=callbacks,
         )
-        return self._process_results(
-            results, docs, token_max, callbacks=callbacks, **kwargs
-        )
+        question_result_key = self.llm_chain.output_key
+        result_docs = [
+            Document(page_content=r[question_result_key], metadata=docs[i].metadata)
+            # This uses metadata from the docs, and the textual results from `results`
+            for i, r in enumerate(map_results)
+        ]
+        result, extra_return_dict = self.reduce_documents_chain.combine_docs(
+            result_docs, callbacks=callbacks, **kwargs
+        )
+        if self.return_intermediate_steps:
+            intermediate_steps = [r[question_result_key] for r in map_results]
+            extra_return_dict["intermediate_steps"] = intermediate_steps
+        return result, extra_return_dict

     async def acombine_docs(
         self, docs: List[Document], callbacks: Callbacks = None, **kwargs: Any
@@ -158,83 +234,24 @@ class MapReduceDocumentsChain(BaseCombineDocumentsChain):
         Combine by mapping first chain over all documents, then reducing the results.
         This reducing can be done recursively if needed (if there are many documents).
         """
-        results = await self.llm_chain.aapply(
+        map_results = await self.llm_chain.aapply(
             # FYI - this is parallelized and so it is fast.
             [{**{self.document_variable_name: d.page_content}, **kwargs} for d in docs],
             callbacks=callbacks,
         )
-        return await self._aprocess_results(
-            results, docs, callbacks=callbacks, **kwargs
-        )
-
-    def _process_results_common(
-        self,
-        results: List[Dict],
-        docs: List[Document],
-        token_max: int = 3000,
-        callbacks: Callbacks = None,
-        **kwargs: Any,
-    ) -> Tuple[List[Document], dict]:
         question_result_key = self.llm_chain.output_key
         result_docs = [
             Document(page_content=r[question_result_key], metadata=docs[i].metadata)
             # This uses metadata from the docs, and the textual results from `results`
-            for i, r in enumerate(results)
+            for i, r in enumerate(map_results)
         ]
-        length_func = self.combine_document_chain.prompt_length
-        num_tokens = length_func(result_docs, **kwargs)
-
-        def _collapse_docs_func(docs: List[Document], **kwargs: Any) -> str:
-            return self._collapse_chain.run(
-                input_documents=docs, callbacks=callbacks, **kwargs
-            )
-
-        while num_tokens is not None and num_tokens > token_max:
-            new_result_doc_list = _split_list_of_docs(
-                result_docs, length_func, token_max, **kwargs
-            )
-            result_docs = []
-            for docs in new_result_doc_list:
-                new_doc = _collapse_docs(docs, _collapse_docs_func, **kwargs)
-                result_docs.append(new_doc)
-            num_tokens = length_func(result_docs, **kwargs)
+        result, extra_return_dict = await self.reduce_documents_chain.acombine_docs(
+            result_docs, callbacks=callbacks, **kwargs
+        )
         if self.return_intermediate_steps:
-            _results = [r[self.llm_chain.output_key] for r in results]
-            extra_return_dict = {"intermediate_steps": _results}
-        else:
-            extra_return_dict = {}
-        return result_docs, extra_return_dict
-
-    def _process_results(
-        self,
-        results: List[Dict],
-        docs: List[Document],
-        token_max: int = 3000,
-        callbacks: Callbacks = None,
-        **kwargs: Any,
-    ) -> Tuple[str, dict]:
-        result_docs, extra_return_dict = self._process_results_common(
-            results, docs, token_max, callbacks=callbacks, **kwargs
-        )
-        output = self.combine_document_chain.run(
-            input_documents=result_docs, callbacks=callbacks, **kwargs
-        )
-        return output, extra_return_dict
-
-    async def _aprocess_results(
-        self,
-        results: List[Dict],
-        docs: List[Document],
-        callbacks: Callbacks = None,
-        **kwargs: Any,
-    ) -> Tuple[str, dict]:
-        result_docs, extra_return_dict = self._process_results_common(
-            results, docs, callbacks=callbacks, **kwargs
-        )
-        output = await self.combine_document_chain.arun(
-            input_documents=result_docs, callbacks=callbacks, **kwargs
-        )
-        return output, extra_return_dict
+            intermediate_steps = [r[question_result_key] for r in map_results]
+            extra_return_dict["intermediate_steps"] = intermediate_steps
+        return result, extra_return_dict

     @property
     def _chain_type(self) -> str:
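
A minimal sketch of the backwards-compatibility path added above: passing the deprecated `combine_document_chain` kwarg should be rewritten by the new `get_reduce_chain` root validator into a nested ReduceDocumentsChain (OpenAI and the prompts here are placeholders):

from langchain.chains import (
    LLMChain,
    MapReduceDocumentsChain,
    ReduceDocumentsChain,
    StuffDocumentsChain,
)
from langchain.llms import OpenAI
from langchain.prompts import PromptTemplate

llm = OpenAI()
map_chain = LLMChain(
    llm=llm, prompt=PromptTemplate.from_template("Summarize: {context}")
)
stuff_chain = StuffDocumentsChain(
    llm_chain=LLMChain(
        llm=llm, prompt=PromptTemplate.from_template("Combine: {context}")
    ),
    document_variable_name="context",
)
# Old spelling still accepted; the validator wraps it into the new field.
chain = MapReduceDocumentsChain(
    llm_chain=map_chain,
    combine_document_chain=stuff_chain,  # deprecated kwarg
    document_variable_name="context",
)
assert isinstance(chain.reduce_documents_chain, ReduceDocumentsChain)
assert chain.combine_document_chain is stuff_chain  # legacy property still works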

View File

@@ -14,7 +14,48 @@ from langchain.output_parsers.regex import RegexParser
 class MapRerankDocumentsChain(BaseCombineDocumentsChain):
-    """Combining documents by mapping a chain over them, then reranking results."""
+    """Combining documents by mapping a chain over them, then reranking results.
+
+    This algorithm calls an LLMChain on each input document. The LLMChain is expected
+    to have an OutputParser that parses the result into both an answer (`answer_key`)
+    and a score (`rank_key`). The answer with the highest score is then returned.
+
+    Example:
+        .. code-block:: python
+
+            from langchain.chains import StuffDocumentsChain, LLMChain
+            from langchain.prompts import PromptTemplate
+            from langchain.llms import OpenAI
+            from langchain.output_parsers.regex import RegexParser
+
+            document_variable_name = "context"
+            llm = OpenAI()
+            # The prompt here should take as an input variable the
+            # `document_variable_name`.
+            # The actual prompt will need to be a lot more complex, this is just
+            # an example.
+            prompt_template = (
+                "Use the following context to tell me the chemical formula "
+                "for water. Output both your answer and a score of how confident "
+                "you are. Context: {context}"
+            )
+            output_parser = RegexParser(
+                regex=r"(.*?)\nScore: (.*)",
+                output_keys=["answer", "score"],
+            )
+            prompt = PromptTemplate(
+                template=prompt_template,
+                input_variables=["context"],
+                output_parser=output_parser,
+            )
+            llm_chain = LLMChain(llm=llm, prompt=prompt)
+            chain = MapRerankDocumentsChain(
+                llm_chain=llm_chain,
+                document_variable_name=document_variable_name,
+                rank_key="score",
+                answer_key="answer",
+            )
+    """

     llm_chain: LLMChain
     """Chain to apply to each document individually."""
@@ -26,7 +67,10 @@ class MapRerankDocumentsChain(BaseCombineDocumentsChain):
     answer_key: str
     """Key in output of llm_chain to return as answer."""
     metadata_keys: Optional[List[str]] = None
+    """Additional metadata from the chosen document to return."""
     return_intermediate_steps: bool = False
+    """Return intermediate steps.
+    Intermediate steps include the results of calling llm_chain on each document."""

     class Config:
         """Configuration for this pydantic object."""
@@ -96,6 +140,16 @@ class MapRerankDocumentsChain(BaseCombineDocumentsChain):
"""Combine documents in a map rerank manner. """Combine documents in a map rerank manner.
Combine by mapping first chain over all documents, then reranking the results. Combine by mapping first chain over all documents, then reranking the results.
Args:
docs: List of documents to combine
callbacks: Callbacks to be passed through
**kwargs: additional parameters to be passed to LLM calls (like other
input variables besides the documents)
Returns:
The first element returned is the single string output. The second
element returned is a dictionary of other keys to return.
""" """
results = self.llm_chain.apply_and_parse( results = self.llm_chain.apply_and_parse(
# FYI - this is parallelized and so it is fast. # FYI - this is parallelized and so it is fast.
@@ -110,6 +164,16 @@ class MapRerankDocumentsChain(BaseCombineDocumentsChain):
"""Combine documents in a map rerank manner. """Combine documents in a map rerank manner.
Combine by mapping first chain over all documents, then reranking the results. Combine by mapping first chain over all documents, then reranking the results.
Args:
docs: List of documents to combine
callbacks: Callbacks to be passed through
**kwargs: additional parameters to be passed to LLM calls (like other
input variables besides the documents)
Returns:
The first element returned is the single string output. The second
element returned is a dictionary of other keys to return.
""" """
results = await self.llm_chain.aapply_and_parse( results = await self.llm_chain.aapply_and_parse(
# FYI - this is parallelized and so it is fast. # FYI - this is parallelized and so it is fast.
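
A small grounded example of the parse-then-rank contract described above; the `RegexParser` here matches the one in the new docstring:

from langchain.output_parsers.regex import RegexParser

parser = RegexParser(
    regex=r"(.*?)\nScore: (.*)",
    output_keys=["answer", "score"],
)
print(parser.parse("H2O\nScore: 95"))  # {'answer': 'H2O', 'score': '95'}
# MapRerankDocumentsChain sorts the per-document parses by `rank_key`
# ("score") and returns the `answer_key` value from the top-ranked result.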

View File

@@ -0,0 +1,277 @@
"""Combine many documents together by recursively reducing them."""
from __future__ import annotations
from typing import Any, Callable, List, Optional, Protocol, Tuple
from pydantic import Extra
from langchain.callbacks.manager import Callbacks
from langchain.chains.combine_documents.base import BaseCombineDocumentsChain
from langchain.docstore.document import Document
class CombineDocsProtocol(Protocol):
"""Interface for the combine_docs method."""
def __call__(self, docs: List[Document], **kwargs: Any) -> str:
"""Interface for the combine_docs method."""
class AsyncCombineDocsProtocol(Protocol):
"""Interface for the combine_docs method."""
async def __call__(self, docs: List[Document], **kwargs: Any) -> str:
"""Async nterface for the combine_docs method."""
def _split_list_of_docs(
docs: List[Document], length_func: Callable, token_max: int, **kwargs: Any
) -> List[List[Document]]:
new_result_doc_list = []
_sub_result_docs = []
for doc in docs:
_sub_result_docs.append(doc)
_num_tokens = length_func(_sub_result_docs, **kwargs)
if _num_tokens > token_max:
if len(_sub_result_docs) == 1:
raise ValueError(
"A single document was longer than the context length,"
" we cannot handle this."
)
new_result_doc_list.append(_sub_result_docs[:-1])
_sub_result_docs = _sub_result_docs[-1:]
new_result_doc_list.append(_sub_result_docs)
return new_result_doc_list
def _collapse_docs(
docs: List[Document],
combine_document_func: CombineDocsProtocol,
**kwargs: Any,
) -> Document:
result = combine_document_func(docs, **kwargs)
combined_metadata = {k: str(v) for k, v in docs[0].metadata.items()}
for doc in docs[1:]:
for k, v in doc.metadata.items():
if k in combined_metadata:
combined_metadata[k] += f", {v}"
else:
combined_metadata[k] = str(v)
return Document(page_content=result, metadata=combined_metadata)
async def _acollapse_docs(
docs: List[Document],
combine_document_func: AsyncCombineDocsProtocol,
**kwargs: Any,
) -> Document:
result = await combine_document_func(docs, **kwargs)
combined_metadata = {k: str(v) for k, v in docs[0].metadata.items()}
for doc in docs[1:]:
for k, v in doc.metadata.items():
if k in combined_metadata:
combined_metadata[k] += f", {v}"
else:
combined_metadata[k] = str(v)
return Document(page_content=result, metadata=combined_metadata)
class ReduceDocumentsChain(BaseCombineDocumentsChain):
"""Combining documents by recursively reducing them.
This involves
- combine_documents_chain
- collapse_documents_chain
`combine_documents_chain` is ALWAYS provided. This is final chain that is called.
We pass all previous results to this chain, and the output of this chain is
returned as a final result.
`collapse_documents_chain` is used if the documents passed in are too many to all
be passed to `combine_documents_chain` in one go. In this case,
`collapse_documents_chain` is called recursively on as big of groups of documents
as are allowed.
Example:
.. code-block:: python
from langchain.chains import (
StuffDocumentsChain, LLMChain, ReduceDocumentsChain
)
from langchain.prompts import PromptTemplate
from langchain.llms import OpenAI
# This controls how each document will be formatted. Specifically,
# it will be passed to `format_document` - see that function for more
# details.
document_prompt = PromptTemplate(
input_variables=["page_content"],
template="{page_content}"
)
document_variable_name = "context"
llm = OpenAI()
# The prompt here should take as an input variable the
# `document_variable_name`
prompt = PromptTemplate.from_template(
"Summarize this content: {context}"
)
llm_chain = LLMChain(llm=llm, prompt=prompt)
combine_documents_chain = StuffDocumentsChain(
llm_chain=llm_chain,
document_prompt=document_prompt,
document_variable_name=document_variable_name
)
chain = ReduceDocumentsChain(
combine_documents_chain=combine_documents_chain,
)
# If we wanted to, we could also pass in collapse_documents_chain
# which is specifically aimed at collapsing documents BEFORE
# the final call.
prompt = PromptTemplate.from_template(
"Collapse this content: {context}"
)
llm_chain = LLMChain(llm=llm, prompt=prompt)
collapse_documents_chain = StuffDocumentsChain(
llm_chain=llm_chain,
document_prompt=document_prompt,
document_variable_name=document_variable_name
)
chain = ReduceDocumentsChain(
combine_documents_chain=combine_documents_chain,
collapse_documents_chain=collapse_documents_chain,
)
"""
combine_documents_chain: BaseCombineDocumentsChain
"""Final chain to call to combine documents.
This is typically a StuffDocumentsChain."""
collapse_documents_chain: Optional[BaseCombineDocumentsChain] = None
"""Chain to use to collapse documents if needed until they can all fit.
If None, will use the combine_documents_chain.
This is typically a StuffDocumentsChain."""
class Config:
"""Configuration for this pydantic object."""
extra = Extra.forbid
arbitrary_types_allowed = True
@property
def _collapse_chain(self) -> BaseCombineDocumentsChain:
if self.collapse_documents_chain is not None:
return self.collapse_documents_chain
else:
return self.combine_documents_chain
def combine_docs(
self,
docs: List[Document],
token_max: int = 3000,
callbacks: Callbacks = None,
**kwargs: Any,
) -> Tuple[str, dict]:
"""Combine multiple documents recursively.
Args:
docs: List of documents to combine, assumed that each one is less than
`token_max`.
token_max: Recursively creates groups of documents less than this number
of tokens.
callbacks: Callbacks to be passed through
**kwargs: additional parameters to be passed to LLM calls (like other
input variables besides the documents)
Returns:
The first element returned is the single string output. The second
element returned is a dictionary of other keys to return.
"""
result_docs, extra_return_dict = self._collapse(
docs, token_max, callbacks=callbacks, **kwargs
)
return self.combine_documents_chain.combine_docs(
docs=result_docs, callbacks=callbacks, **kwargs
)
async def acombine_docs(
self, docs: List[Document], callbacks: Callbacks = None, **kwargs: Any
) -> Tuple[str, dict]:
"""Combine multiple documents recursively.
Args:
docs: List of documents to combine, assumed that each one is less than
`token_max`.
token_max: Recursively creates groups of documents less than this number
of tokens.
callbacks: Callbacks to be passed through
**kwargs: additional parameters to be passed to LLM calls (like other
input variables besides the documents)
Returns:
The first element returned is the single string output. The second
element returned is a dictionary of other keys to return.
"""
result_docs, extra_return_dict = await self._acollapse(
docs, callbacks=callbacks, **kwargs
)
return await self.combine_documents_chain.acombine_docs(
docs=result_docs, callbacks=callbacks, **kwargs
)
def _collapse(
self,
docs: List[Document],
token_max: int = 3000,
callbacks: Callbacks = None,
**kwargs: Any,
) -> Tuple[List[Document], dict]:
result_docs = docs
length_func = self.combine_documents_chain.prompt_length
num_tokens = length_func(result_docs, **kwargs)
def _collapse_docs_func(docs: List[Document], **kwargs: Any) -> str:
return self._collapse_chain.run(
input_documents=docs, callbacks=callbacks, **kwargs
)
while num_tokens is not None and num_tokens > token_max:
new_result_doc_list = _split_list_of_docs(
result_docs, length_func, token_max, **kwargs
)
result_docs = []
for docs in new_result_doc_list:
new_doc = _collapse_docs(docs, _collapse_docs_func, **kwargs)
result_docs.append(new_doc)
num_tokens = length_func(result_docs, **kwargs)
return result_docs, {}
async def _acollapse(
self,
docs: List[Document],
token_max: int = 3000,
callbacks: Callbacks = None,
**kwargs: Any,
) -> Tuple[List[Document], dict]:
result_docs = docs
length_func = self.combine_documents_chain.prompt_length
num_tokens = length_func(result_docs, **kwargs)
async def _collapse_docs_func(docs: List[Document], **kwargs: Any) -> str:
return await self._collapse_chain.arun(
input_documents=docs, callbacks=callbacks, **kwargs
)
while num_tokens is not None and num_tokens > token_max:
new_result_doc_list = _split_list_of_docs(
result_docs, length_func, token_max, **kwargs
)
result_docs = []
for docs in new_result_doc_list:
new_doc = await _acollapse_docs(docs, _collapse_docs_func, **kwargs)
result_docs.append(new_doc)
num_tokens = length_func(result_docs, **kwargs)
return result_docs, {}
@property
def _chain_type(self) -> str:
return "reduce_documents_chain"

View File

@@ -9,12 +9,11 @@ from pydantic import Extra, Field, root_validator
 from langchain.callbacks.manager import Callbacks
 from langchain.chains.combine_documents.base import (
     BaseCombineDocumentsChain,
-    format_document,
 )
 from langchain.chains.llm import LLMChain
 from langchain.docstore.document import Document
 from langchain.prompts.prompt import PromptTemplate
-from langchain.schema import BasePromptTemplate
+from langchain.schema import BasePromptTemplate, format_document
def _get_default_document_prompt() -> PromptTemplate: def _get_default_document_prompt() -> PromptTemplate:
@@ -22,7 +21,55 @@ def _get_default_document_prompt() -> PromptTemplate:
 class RefineDocumentsChain(BaseCombineDocumentsChain):
-    """Combine documents by doing a first pass and then refining on more documents."""
+    """Combine documents by doing a first pass and then refining on more documents.
+
+    This algorithm first calls `initial_llm_chain` on the first document, passing
+    that first document in with the variable name `document_variable_name`, and
+    produces a new variable with the variable name `initial_response_name`.
+
+    Then, it loops over every remaining document. This is called the "refine" step.
+    It calls `refine_llm_chain`,
+    passing in that document with the variable name `document_variable_name`
+    as well as the previous response with the variable name `initial_response_name`.
+
+    Example:
+        .. code-block:: python
+
+            from langchain.chains import RefineDocumentsChain, LLMChain
+            from langchain.prompts import PromptTemplate
+            from langchain.llms import OpenAI
+
+            # This controls how each document will be formatted. Specifically,
+            # it will be passed to `format_document` - see that function for more
+            # details.
+            document_prompt = PromptTemplate(
+                input_variables=["page_content"],
+                template="{page_content}"
+            )
+            document_variable_name = "context"
+            llm = OpenAI()
+            # The prompt here should take as an input variable the
+            # `document_variable_name`
+            prompt = PromptTemplate.from_template(
+                "Summarize this content: {context}"
+            )
+            initial_llm_chain = LLMChain(llm=llm, prompt=prompt)
+            initial_response_name = "prev_response"
+            # The prompt here should take as an input variable the
+            # `document_variable_name` as well as `initial_response_name`
+            prompt_refine = PromptTemplate.from_template(
+                "Here's your first summary: {prev_response}. "
+                "Now add to it based on the following context: {context}"
+            )
+            refine_llm_chain = LLMChain(llm=llm, prompt=prompt_refine)
+            chain = RefineDocumentsChain(
+                initial_llm_chain=initial_llm_chain,
+                refine_llm_chain=refine_llm_chain,
+                document_prompt=document_prompt,
+                document_variable_name=document_variable_name,
+                initial_response_name=initial_response_name,
+            )
+    """

     initial_llm_chain: LLMChain
     """LLM chain to use on initial document."""
@@ -36,7 +83,7 @@ class RefineDocumentsChain(BaseCombineDocumentsChain):
     document_prompt: BasePromptTemplate = Field(
         default_factory=_get_default_document_prompt
     )
-    """Prompt to use to format each document."""
+    """Prompt to use to format each document, gets passed to `format_document`."""
     return_intermediate_steps: bool = False
     """Return the results of the refine steps in the output."""
@@ -89,7 +136,18 @@ class RefineDocumentsChain(BaseCombineDocumentsChain):
     def combine_docs(
         self, docs: List[Document], callbacks: Callbacks = None, **kwargs: Any
     ) -> Tuple[str, dict]:
-        """Combine by mapping first chain over all, then stuffing into final chain."""
+        """Combine by mapping first chain over all, then stuffing into final chain.
+
+        Args:
+            docs: List of documents to combine
+            callbacks: Callbacks to be passed through
+            **kwargs: additional parameters to be passed to LLM calls (like other
+                input variables besides the documents)
+
+        Returns:
+            The first element returned is the single string output. The second
+            element returned is a dictionary of other keys to return.
+        """
         inputs = self._construct_initial_inputs(docs, **kwargs)
         res = self.initial_llm_chain.predict(callbacks=callbacks, **inputs)
         refine_steps = [res]
@@ -103,7 +161,18 @@ class RefineDocumentsChain(BaseCombineDocumentsChain):
     async def acombine_docs(
         self, docs: List[Document], callbacks: Callbacks = None, **kwargs: Any
     ) -> Tuple[str, dict]:
-        """Combine by mapping first chain over all, then stuffing into final chain."""
+        """Combine by mapping first chain over all, then stuffing into final chain.
+
+        Args:
+            docs: List of documents to combine
+            callbacks: Callbacks to be passed through
+            **kwargs: additional parameters to be passed to LLM calls (like other
+                input variables besides the documents)
+
+        Returns:
+            The first element returned is the single string output. The second
+            element returned is a dictionary of other keys to return.
+        """
         inputs = self._construct_initial_inputs(docs, **kwargs)
         res = await self.initial_llm_chain.apredict(callbacks=callbacks, **inputs)
         refine_steps = [res]
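
To make the control flow above concrete, here is a tiny plain-Python stand-in for the refine loop (illustrative only, not the chain's real internals):

# Stand-in: `initial_step` plays initial_llm_chain, `refine_step` plays
# refine_llm_chain; strings stand in for formatted documents.
def refine(docs, initial_step, refine_step):
    response = initial_step(docs[0])
    intermediate_steps = [response]
    for doc in docs[1:]:
        response = refine_step(doc, response)
        intermediate_steps.append(response)
    return response, {"intermediate_steps": intermediate_steps}

final, extra = refine(
    ["doc A", "doc B", "doc C"],
    initial_step=lambda d: f"summary({d})",
    refine_step=lambda d, prev: f"refine({prev}, {d})",
)
print(final)  # refine(refine(summary(doc A), doc B), doc C)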

View File

@@ -7,12 +7,11 @@ from pydantic import Extra, Field, root_validator
 from langchain.callbacks.manager import Callbacks
 from langchain.chains.combine_documents.base import (
     BaseCombineDocumentsChain,
-    format_document,
 )
 from langchain.chains.llm import LLMChain
 from langchain.docstore.document import Document
 from langchain.prompts.prompt import PromptTemplate
-from langchain.schema import BasePromptTemplate
+from langchain.schema import BasePromptTemplate, format_document
def _get_default_document_prompt() -> PromptTemplate: def _get_default_document_prompt() -> PromptTemplate:
@@ -20,14 +19,50 @@ def _get_default_document_prompt() -> PromptTemplate:
 class StuffDocumentsChain(BaseCombineDocumentsChain):
-    """Chain that combines documents by stuffing into context."""
+    """Chain that combines documents by stuffing them into context.
+
+    This chain takes a list of documents and first combines them into a single string.
+    It does this by formatting each document into a string with the `document_prompt`
+    and then joining them together with `document_separator`. It then adds that new
+    string to the inputs with the variable name set by `document_variable_name`.
+    Those inputs are then passed to the `llm_chain`.
+
+    Example:
+        .. code-block:: python
+
+            from langchain.chains import StuffDocumentsChain, LLMChain
+            from langchain.prompts import PromptTemplate
+            from langchain.llms import OpenAI
+
+            # This controls how each document will be formatted. Specifically,
+            # it will be passed to `format_document` - see that function for more
+            # details.
+            document_prompt = PromptTemplate(
+                input_variables=["page_content"],
+                template="{page_content}"
+            )
+            document_variable_name = "context"
+            llm = OpenAI()
+            # The prompt here should take as an input variable the
+            # `document_variable_name`
+            prompt = PromptTemplate.from_template(
+                "Summarize this content: {context}"
+            )
+            llm_chain = LLMChain(llm=llm, prompt=prompt)
+            chain = StuffDocumentsChain(
+                llm_chain=llm_chain,
+                document_prompt=document_prompt,
+                document_variable_name=document_variable_name
+            )
+    """

     llm_chain: LLMChain
-    """LLM wrapper to use after formatting documents."""
+    """LLM chain which is called with the formatted document string,
+    along with any other inputs."""
     document_prompt: BasePromptTemplate = Field(
         default_factory=_get_default_document_prompt
     )
-    """Prompt to use to format each document."""
+    """Prompt to use to format each document, gets passed to `format_document`."""
     document_variable_name: str
     """The variable name in the llm_chain to put the documents in.
     If only one variable in the llm_chain, this need not be provided."""
@@ -42,7 +77,12 @@ class StuffDocumentsChain(BaseCombineDocumentsChain):
     @root_validator(pre=True)
     def get_default_document_variable_name(cls, values: Dict) -> Dict:
-        """Get default document variable name, if not provided."""
+        """Get default document variable name, if not provided.
+
+        If only one variable is present in the llm_chain.prompt,
+        we can infer that the formatted documents should be passed in
+        with this variable name.
+        """
         llm_chain_variables = values["llm_chain"].prompt.input_variables
         if "document_variable_name" not in values:
             if len(llm_chain_variables) == 1:
@@ -61,6 +101,20 @@ class StuffDocumentsChain(BaseCombineDocumentsChain):
         return values

     def _get_inputs(self, docs: List[Document], **kwargs: Any) -> dict:
+        """Construct inputs from kwargs and docs.
+
+        Format and then join all the documents together into one input with name
+        `self.document_variable_name`. Also pluck any additional variables
+        from **kwargs.
+
+        Args:
+            docs: List of documents to format and then join into single input
+            **kwargs: additional inputs to chain, will pluck any other required
+                arguments from here.
+
+        Returns:
+            dictionary of inputs to LLMChain
+        """
         # Format each document according to the prompt
         doc_strings = [format_document(doc, self.document_prompt) for doc in docs]
         # Join the documents together to put them in the prompt.
@@ -73,7 +127,21 @@ class StuffDocumentsChain(BaseCombineDocumentsChain):
         return inputs

     def prompt_length(self, docs: List[Document], **kwargs: Any) -> Optional[int]:
-        """Get the prompt length by formatting the prompt."""
+        """Return the prompt length given the documents passed in.
+
+        This can be used by a caller to determine whether passing in a list
+        of documents would exceed a certain prompt length. This is useful when
+        trying to ensure that the size of a prompt remains below a certain
+        context limit.
+
+        Args:
+            docs: List[Document], a list of documents to use to calculate the
+                total prompt length.
+
+        Returns:
+            Returns None if the method does not depend on the prompt length,
+            otherwise the length of the prompt in tokens.
+        """
         inputs = self._get_inputs(docs, **kwargs)
         prompt = self.llm_chain.prompt.format(**inputs)
         return self.llm_chain.llm.get_num_tokens(prompt)
@@ -81,7 +149,17 @@ class StuffDocumentsChain(BaseCombineDocumentsChain):
     def combine_docs(
         self, docs: List[Document], callbacks: Callbacks = None, **kwargs: Any
     ) -> Tuple[str, dict]:
-        """Stuff all documents into one prompt and pass to LLM."""
+        """Stuff all documents into one prompt and pass to LLM.
+
+        Args:
+            docs: List of documents to join together into one variable
+            callbacks: Optional callbacks to pass along
+            **kwargs: additional parameters to use to get inputs to LLMChain.
+
+        Returns:
+            The first element returned is the single string output. The second
+            element returned is a dictionary of other keys to return.
+        """
         inputs = self._get_inputs(docs, **kwargs)
         # Call predict on the LLM.
         return self.llm_chain.predict(callbacks=callbacks, **inputs), {}
@@ -89,7 +167,17 @@ class StuffDocumentsChain(BaseCombineDocumentsChain):
     async def acombine_docs(
         self, docs: List[Document], callbacks: Callbacks = None, **kwargs: Any
     ) -> Tuple[str, dict]:
-        """Stuff all documents into one prompt and pass to LLM."""
+        """Stuff all documents into one prompt and pass to LLM.
+
+        Args:
+            docs: List of documents to join together into one variable
+            callbacks: Optional callbacks to pass along
+            **kwargs: additional parameters to use to get inputs to LLMChain.
+
+        Returns:
+            The first element returned is the single string output. The second
+            element returned is a dictionary of other keys to return.
+        """
         inputs = self._get_inputs(docs, **kwargs)
         # Call predict on the LLM.
         return await self.llm_chain.apredict(callbacks=callbacks, **inputs), {}
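
To see what `_get_inputs` produces, here is a small grounded sketch of the format-then-join step using the `format_document` helper this file now imports from `langchain.schema` (the "\n\n" separator is the chain's default):

from langchain.docstore.document import Document
from langchain.prompts.prompt import PromptTemplate
from langchain.schema import format_document

document_prompt = PromptTemplate(
    input_variables=["page_content"], template="{page_content}"
)
docs = [Document(page_content="alpha"), Document(page_content="beta")]
doc_strings = [format_document(d, document_prompt) for d in docs]
# This joined string is what lands in the prompt under document_variable_name.
print("\n\n".join(doc_strings))  # alpha, blank line, beta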

View File

@@ -5,6 +5,7 @@ from typing import Any, Union
 import yaml

+from langchain.chains import ReduceDocumentsChain
 from langchain.chains.api.base import APIChain
 from langchain.chains.base import Chain
 from langchain.chains.combine_documents.map_reduce import MapReduceDocumentsChain
@@ -117,9 +118,9 @@ def _load_map_reduce_documents_chain(
if "combine_document_chain" in config: if "combine_document_chain" in config:
combine_document_chain_config = config.pop("combine_document_chain") combine_document_chain_config = config.pop("combine_document_chain")
combine_document_chain = load_chain_from_config(combine_document_chain_config) combine_documents_chain = load_chain_from_config(combine_document_chain_config)
elif "combine_document_chain_path" in config: elif "combine_document_chain_path" in config:
combine_document_chain = load_chain(config.pop("combine_document_chain_path")) combine_documents_chain = load_chain(config.pop("combine_document_chain_path"))
else: else:
raise ValueError( raise ValueError(
"One of `combine_document_chain` or " "One of `combine_document_chain` or "
@@ -128,17 +129,24 @@ def _load_map_reduce_documents_chain(
if "collapse_document_chain" in config: if "collapse_document_chain" in config:
collapse_document_chain_config = config.pop("collapse_document_chain") collapse_document_chain_config = config.pop("collapse_document_chain")
if collapse_document_chain_config is None: if collapse_document_chain_config is None:
collapse_document_chain = None collapse_documents_chain = None
else: else:
collapse_document_chain = load_chain_from_config( collapse_documents_chain = load_chain_from_config(
collapse_document_chain_config collapse_document_chain_config
) )
elif "collapse_document_chain_path" in config: elif "collapse_document_chain_path" in config:
collapse_document_chain = load_chain(config.pop("collapse_document_chain_path")) collapse_documents_chain = load_chain(
config.pop("collapse_document_chain_path")
)
else:
collapse_documents_chain = None
reduce_documents_chain = ReduceDocumentsChain(
combine_documents_chain=combine_documents_chain,
collapse_documents_chain=collapse_documents_chain,
)
return MapReduceDocumentsChain( return MapReduceDocumentsChain(
llm_chain=llm_chain, llm_chain=llm_chain,
combine_document_chain=combine_document_chain, reduce_documents_chain=reduce_documents_chain,
collapse_document_chain=collapse_document_chain,
**config, **config,
) )

View File

@@ -11,6 +11,7 @@ from pydantic import Extra
 from langchain.base_language import BaseLanguageModel
 from langchain.callbacks.manager import CallbackManagerForChainRun, Callbacks
+from langchain.chains import ReduceDocumentsChain
 from langchain.chains.base import Chain
 from langchain.chains.combine_documents.base import BaseCombineDocumentsChain
 from langchain.chains.combine_documents.map_reduce import MapReduceDocumentsChain
@@ -44,14 +45,17 @@ class MapReduceChain(Chain):
     ) -> MapReduceChain:
         """Construct a map-reduce chain that uses the chain for map and reduce."""
         llm_chain = LLMChain(llm=llm, prompt=prompt, callbacks=callbacks)
-        reduce_chain = StuffDocumentsChain(
+        stuff_chain = StuffDocumentsChain(
             llm_chain=llm_chain,
             callbacks=callbacks,
             **(reduce_chain_kwargs if reduce_chain_kwargs else {}),
         )
+        reduce_documents_chain = ReduceDocumentsChain(
+            combine_documents_chain=stuff_chain
+        )
         combine_documents_chain = MapReduceDocumentsChain(
             llm_chain=llm_chain,
-            combine_document_chain=reduce_chain,
+            reduce_documents_chain=reduce_documents_chain,
             callbacks=callbacks,
             **(combine_chain_kwargs if combine_chain_kwargs else {}),
         )
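
A hedged usage sketch for the updated `from_params` path (OpenAI as a stand-in LLM; any text splitter works):

from langchain.chains.mapreduce import MapReduceChain
from langchain.llms import OpenAI
from langchain.prompts import PromptTemplate
from langchain.text_splitter import CharacterTextSplitter

# from_params now nests LLMChain -> StuffDocumentsChain -> ReduceDocumentsChain
# -> MapReduceDocumentsChain, exactly as wired above.
chain = MapReduceChain.from_params(
    llm=OpenAI(),
    prompt=PromptTemplate.from_template("Summarize this content: {context}"),
    text_splitter=CharacterTextSplitter(),
)
summary = chain.run("...some long text to split, map, and reduce...")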

View File

@@ -14,6 +14,7 @@ from langchain.callbacks.manager import (
     AsyncCallbackManagerForChainRun,
     CallbackManagerForChainRun,
 )
+from langchain.chains import ReduceDocumentsChain
 from langchain.chains.base import Chain
 from langchain.chains.combine_documents.base import BaseCombineDocumentsChain
 from langchain.chains.combine_documents.map_reduce import MapReduceDocumentsChain
@@ -58,13 +59,16 @@ class BaseQAWithSourcesChain(Chain, ABC):
             document_prompt=document_prompt,
             document_variable_name="summaries",
         )
-        combine_document_chain = MapReduceDocumentsChain(
+        reduce_documents_chain = ReduceDocumentsChain(
+            combine_documents_chain=combine_results_chain
+        )
+        combine_documents_chain = MapReduceDocumentsChain(
             llm_chain=llm_question_chain,
-            combine_document_chain=combine_results_chain,
+            reduce_documents_chain=reduce_documents_chain,
             document_variable_name="context",
         )
         return cls(
-            combine_documents_chain=combine_document_chain,
+            combine_documents_chain=combine_documents_chain,
             **kwargs,
         )
@@ -78,10 +82,10 @@ class BaseQAWithSourcesChain(Chain, ABC):
     ) -> BaseQAWithSourcesChain:
         """Load chain from chain type."""
         _chain_kwargs = chain_type_kwargs or {}
-        combine_document_chain = load_qa_with_sources_chain(
+        combine_documents_chain = load_qa_with_sources_chain(
             llm, chain_type=chain_type, **_chain_kwargs
         )
-        return cls(combine_documents_chain=combine_document_chain, **kwargs)
+        return cls(combine_documents_chain=combine_documents_chain, **kwargs)

     class Config:
         """Configuration for this pydantic object."""
@@ -110,7 +114,7 @@ class BaseQAWithSourcesChain(Chain, ABC):
     @root_validator(pre=True)
     def validate_naming(cls, values: Dict) -> Dict:
-        """Fix backwards compatability in naming."""
+        """Fix backwards compatibility in naming."""
         if "combine_document_chain" in values:
             values["combine_documents_chain"] = values.pop("combine_document_chain")
         return values
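
For context, a short hedged sketch of how a caller exercises this path (requires an OpenAI key; `from_chain_type` routes through `load_qa_with_sources_chain`):

from langchain.chains import QAWithSourcesChain
from langchain.docstore.document import Document
from langchain.llms import OpenAI

docs = [Document(page_content="LangChain is a framework.", metadata={"source": "1"})]
chain = QAWithSourcesChain.from_chain_type(OpenAI(), chain_type="map_reduce")
# The map_reduce variant now carries a nested ReduceDocumentsChain built
# exactly as in the `from_llm` body above.
result = chain({"docs": docs, "question": "What is LangChain?"})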

View File

@@ -4,6 +4,7 @@ from __future__ import annotations
 from typing import Any, Mapping, Optional, Protocol

 from langchain.base_language import BaseLanguageModel
+from langchain.chains import ReduceDocumentsChain
 from langchain.chains.combine_documents.base import BaseCombineDocumentsChain
 from langchain.chains.combine_documents.map_reduce import MapReduceDocumentsChain
 from langchain.chains.combine_documents.map_rerank import MapRerankDocumentsChain
@@ -83,7 +84,7 @@ def _load_map_reduce_chain(
     map_chain = LLMChain(llm=llm, prompt=question_prompt, verbose=verbose)
     _reduce_llm = reduce_llm or llm
     reduce_chain = LLMChain(llm=_reduce_llm, prompt=combine_prompt, verbose=verbose)
-    combine_document_chain = StuffDocumentsChain(
+    combine_documents_chain = StuffDocumentsChain(
         llm_chain=reduce_chain,
         document_variable_name=combine_document_variable_name,
         document_prompt=document_prompt,
@@ -107,11 +108,14 @@ def _load_map_reduce_chain(
document_variable_name=combine_document_variable_name, document_variable_name=combine_document_variable_name,
document_prompt=document_prompt, document_prompt=document_prompt,
) )
reduce_documents_chain = ReduceDocumentsChain(
combine_documents_chain=combine_documents_chain,
collapse_documents_chain=collapse_chain,
)
return MapReduceDocumentsChain( return MapReduceDocumentsChain(
llm_chain=map_chain, llm_chain=map_chain,
combine_document_chain=combine_document_chain, reduce_documents_chain=reduce_documents_chain,
document_variable_name=map_reduce_document_variable_name, document_variable_name=map_reduce_document_variable_name,
collapse_document_chain=collapse_chain,
verbose=verbose, verbose=verbose,
**kwargs, **kwargs,
) )
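The optional collapse step moves along with it: `collapse_document_chain` on `MapReduceDocumentsChain` becomes `collapse_documents_chain` on `ReduceDocumentsChain`. A rough before/after, reusing the names from the sketch above plus a hypothetical `collapse_chain`:

    # Before this PR (for comparison only):
    #   MapReduceDocumentsChain(..., combine_document_chain=combine_results_chain,
    #                           collapse_document_chain=collapse_chain)
    # After: both the combine and collapse chains hang off the reduce chain.
    reduce_documents_chain = ReduceDocumentsChain(
        combine_documents_chain=combine_results_chain,
        collapse_documents_chain=collapse_chain,  # hypothetical collapse chain
    )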

View File

@@ -4,6 +4,7 @@ from typing import Any, Mapping, Optional, Protocol
 from langchain.base_language import BaseLanguageModel
 from langchain.callbacks.base import BaseCallbackManager
 from langchain.callbacks.manager import Callbacks
+from langchain.chains import ReduceDocumentsChain
 from langchain.chains.combine_documents.base import BaseCombineDocumentsChain
 from langchain.chains.combine_documents.map_reduce import MapReduceDocumentsChain
 from langchain.chains.combine_documents.map_rerank import MapRerankDocumentsChain
@@ -122,7 +123,7 @@ def _load_map_reduce_chain(
         callbacks=callbacks,
     )
     # TODO: document prompt
-    combine_document_chain = StuffDocumentsChain(
+    combine_documents_chain = StuffDocumentsChain(
         llm_chain=reduce_chain,
         document_variable_name=combine_document_variable_name,
         verbose=verbose,
@@ -150,11 +151,14 @@ def _load_map_reduce_chain(
         verbose=verbose,
         callback_manager=callback_manager,
     )
+    reduce_documents_chain = ReduceDocumentsChain(
+        combine_documents_chain=combine_documents_chain,
+        collapse_documents_chain=collapse_chain,
+    )
     return MapReduceDocumentsChain(
         llm_chain=map_chain,
-        combine_document_chain=combine_document_chain,
         document_variable_name=map_reduce_document_variable_name,
-        collapse_document_chain=collapse_chain,
+        reduce_documents_chain=reduce_documents_chain,
         verbose=verbose,
         callback_manager=callback_manager,
         callbacks=callbacks,
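Callers of the public question-answering loader should see no behavior change from this rewiring. A hypothetical smoke test, assuming an OpenAI key is configured:

    from langchain.chains.question_answering import load_qa_chain
    from langchain.docstore.document import Document
    from langchain.llms import OpenAI

    chain = load_qa_chain(OpenAI(temperature=0), chain_type="map_reduce")
    docs = [Document(page_content="The sky is blue because of Rayleigh scattering.")]
    print(chain.run(input_documents=docs, question="Why is the sky blue?"))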

View File

@@ -2,6 +2,7 @@
 from typing import Any, Mapping, Optional, Protocol

 from langchain.base_language import BaseLanguageModel
+from langchain.chains import ReduceDocumentsChain
 from langchain.chains.combine_documents.base import BaseCombineDocumentsChain
 from langchain.chains.combine_documents.map_reduce import MapReduceDocumentsChain
 from langchain.chains.combine_documents.refine import RefineDocumentsChain
@@ -53,7 +54,7 @@ def _load_map_reduce_chain(
     _reduce_llm = reduce_llm or llm
     reduce_chain = LLMChain(llm=_reduce_llm, prompt=combine_prompt, verbose=verbose)
     # TODO: document prompt
-    combine_document_chain = StuffDocumentsChain(
+    combine_documents_chain = StuffDocumentsChain(
         llm_chain=reduce_chain,
         document_variable_name=combine_document_variable_name,
         verbose=verbose,
@@ -75,11 +76,14 @@ def _load_map_reduce_chain(
         ),
         document_variable_name=combine_document_variable_name,
     )
+    reduce_documents_chain = ReduceDocumentsChain(
+        combine_documents_chain=combine_documents_chain,
+        collapse_documents_chain=collapse_chain,
+    )
     return MapReduceDocumentsChain(
         llm_chain=map_chain,
-        combine_document_chain=combine_document_chain,
+        reduce_documents_chain=reduce_documents_chain,
         document_variable_name=map_reduce_document_variable_name,
-        collapse_document_chain=collapse_chain,
         verbose=verbose,
         **kwargs,
     )
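The same holds for summarization: the high-level loader's signature is unchanged. A quick smoke test, again assuming a configured OpenAI key:

    from langchain.chains.summarize import load_summarize_chain
    from langchain.docstore.document import Document
    from langchain.llms import OpenAI

    chain = load_summarize_chain(OpenAI(temperature=0), chain_type="map_reduce")
    docs = [Document(page_content="LangChain composes LLM calls into chains.")]
    print(chain.run(docs))  # prints a short summary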

View File

@@ -28,7 +28,7 @@ from langchain.schema.output_parser import (
     OutputParserException,
 )
 from langchain.schema.prompt import PromptValue
-from langchain.schema.prompt_template import BasePromptTemplate
+from langchain.schema.prompt_template import BasePromptTemplate, format_document
 from langchain.schema.retriever import BaseRetriever

 RUN_KEY = "__run"
@@ -66,4 +66,5 @@ __all__ = [
     "BaseOutputParser",
     "BaseLLMOutputParser",
     "BasePromptTemplate",
+    "format_document",
 ]
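With the re-export above, downstream code can import the helper from the top-level schema module rather than reaching into `prompt_template`:

    from langchain.schema import format_document  # new public import path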

View File

@@ -9,7 +9,9 @@ import yaml
 from pydantic import Field, root_validator

 from langchain.load.serializable import Serializable
-from langchain.schema import BaseOutputParser, PromptValue
+from langchain.schema.document import Document
+from langchain.schema.output_parser import BaseOutputParser
+from langchain.schema.prompt import PromptValue


 class BasePromptTemplate(Serializable, ABC):
@@ -137,3 +139,48 @@ class BasePromptTemplate(Serializable, ABC):
             yaml.dump(prompt_dict, f, default_flow_style=False)
         else:
             raise ValueError(f"{save_path} must be json or yaml")
+
+
+def format_document(doc: Document, prompt: BasePromptTemplate) -> str:
+    """Format a document into a string based on a prompt template.
+
+    First, this pulls information from the document from two sources:
+
+    1. `page_content`: this takes the information from the `document.page_content`
+        and assigns it to a variable named `page_content`.
+    2. metadata: this takes information from `document.metadata` and assigns
+        it to variables of the same name.
+
+    Those variables are then passed into the `prompt` to produce a formatted string.
+
+    Args:
+        doc: Document, the page_content and metadata will be used to create
+            the final string.
+        prompt: BasePromptTemplate, will be used to format the page_content
+            and metadata into the final string.
+
+    Returns:
+        string of the document formatted.
+
+    Example:
+        .. code-block:: python
+
+            from langchain.schema import Document
+            from langchain.prompts import PromptTemplate
+
+            doc = Document(page_content="This is a joke", metadata={"page": "1"})
+            prompt = PromptTemplate.from_template("Page {page}: {page_content}")
+            format_document(doc, prompt)
+            >>> "Page 1: This is a joke"
+    """
+    base_info = {"page_content": doc.page_content, **doc.metadata}
+    missing_metadata = set(prompt.input_variables).difference(base_info)
+    if len(missing_metadata) > 0:
+        required_metadata = [
+            iv for iv in prompt.input_variables if iv != "page_content"
+        ]
+        raise ValueError(
+            f"Document prompt requires documents to have metadata variables: "
+            f"{required_metadata}. Received document with missing metadata: "
+            f"{list(missing_metadata)}."
+        )
+    document_info = {k: base_info[k] for k in prompt.input_variables}
+    return prompt.format(**document_info)
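Beyond the docstring's happy path, note the error branch: a template variable with no matching metadata raises `ValueError`. A small check of that behavior (the document and template here are illustrative):

    from langchain.prompts import PromptTemplate
    from langchain.schema import Document, format_document

    prompt = PromptTemplate.from_template("Page {page}: {page_content}")
    doc = Document(page_content="This is a joke")  # note: no "page" metadata

    try:
        format_document(doc, prompt)
    except ValueError as err:
        print(err)  # names the missing metadata variable, here "page"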

View File

@@ -5,12 +5,12 @@ from typing import Any, List
 import pytest

 from langchain import PromptTemplate
-from langchain.chains.combine_documents.base import format_document
-from langchain.chains.combine_documents.map_reduce import (
+from langchain.chains.combine_documents.reduce import (
     _collapse_docs,
     _split_list_of_docs,
 )
 from langchain.docstore.document import Document
+from langchain.schema import format_document


 def _fake_docs_len_func(docs: List[Document]) -> int:
@@ -28,13 +28,6 @@ def test__split_list_long_single_doc() -> None:
         _split_list_of_docs(docs, _fake_docs_len_func, 100)


-def test__split_list_long_pair_doc() -> None:
-    """Test splitting of a list with two medium docs."""
-    docs = [Document(page_content="foo" * 30)] * 2
-    with pytest.raises(ValueError):
-        _split_list_of_docs(docs, _fake_docs_len_func, 100)
-
-
 def test__split_list_single_doc() -> None:
     """Test splitting works with just a single doc."""
     docs = [Document(page_content="foo")]
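The surviving tests exercise `_split_list_of_docs` from its new home in `reduce.py`. A sketch of what that (private) helper does, reusing the tests' character-count style of length function:

    from langchain.chains.combine_documents.reduce import _split_list_of_docs
    from langchain.docstore.document import Document

    def _len(docs):
        # Same idea as _fake_docs_len_func: total characters across all docs.
        return sum(len(doc.page_content) for doc in docs)

    docs = [Document(page_content="foo" * 10) for _ in range(5)]  # 30 chars each
    groups = _split_list_of_docs(docs, _len, 100)
    print([len(g) for g in groups])  # each sub-list stays within the 100-char cap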

View File

@@ -86,8 +86,8 @@ def test_imports() -> None:
     from langchain.document_loaders import BSHTMLLoader  # noqa: F401
     from langchain.embeddings import OpenAIEmbeddings  # noqa: F401
     from langchain.llms import OpenAI  # noqa: F401
-    from langchain.prompts import BasePromptTemplate  # noqa: F401
     from langchain.retrievers import VespaRetriever  # noqa: F401
+    from langchain.schema import BasePromptTemplate  # noqa: F401
     from langchain.tools import DuckDuckGoSearchResults  # noqa: F401
     from langchain.utilities import SerpAPIWrapper  # noqa: F401
     from langchain.vectorstores import FAISS  # noqa: F401