Compare commits

...

51 Commits

Author  SHA1  Message  Date
Bagatur  a1ecf827c3  stuff  2023-12-27 17:53:42 -05:00
Bagatur  5a40db5387  fmt  2023-12-27 17:48:00 -05:00
Bagatur  e46ab1c580  cr  2023-12-27 17:11:29 -05:00
Bagatur  80efbad918  fmt  2023-12-27 16:54:10 -05:00
Bagatur  52b9e064ce  fmt  2023-12-27 16:52:45 -05:00
Bagatur  ffa2b5732f  Merge branch 'master' into bagatur/stuff_docs_lcel  2023-12-27 16:45:42 -05:00
Bagatur  50c9d652f4  fix  2023-12-27 16:45:34 -05:00
Bagatur  ded7fa3df7  new syntax  2023-12-27 16:33:15 -05:00
Bagatur  958f75df10  Merge branch 'master' into bagatur/stuff_docs_lcel  2023-12-27 14:49:46 -05:00
Bagatur  c063a87aec  fmt  2023-12-27 12:57:07 -05:00
Bagatur  2ca8f94d07  fmt  2023-12-27 12:54:46 -05:00
Bagatur  f55adeaf8e  Merge branch 'master' into bagatur/stuff_docs_lcel  2023-12-27 12:49:40 -05:00
Bagatur  da0d18c184  fmt  2023-12-27 12:49:29 -05:00
Bagatur  531c68dc5f  with_name  2023-12-27 11:43:07 -05:00
Bagatur  7d6e8d6ed5  fmt  2023-12-27 11:06:10 -05:00
Bagatur  78ff12812c  fmt  2023-12-27 10:48:38 -05:00
Bagatur  66c3c4fb4b  cr  2023-12-27 10:14:45 -05:00
Bagatur  e01a3458e9  cr  2023-12-27 01:23:36 -05:00
Bagatur  581172dc51  fmt  2023-12-26 19:57:15 -05:00
Bagatur  7ba35b0ff4  fmt  2023-12-26 19:11:09 -05:00
Bagatur  0725ed1ee5  fmt  2023-12-26 19:05:45 -05:00
Bagatur  4edc79ad04  fmt  2023-12-26 19:04:07 -05:00
Bagatur  11d14c8b95  fmt  2023-12-26 18:56:52 -05:00
Bagatur  d06d6b99e3  fmt  2023-12-26 18:24:38 -05:00
Bagatur  be4c76cd87  fmt  2023-12-26 17:46:33 -05:00
Bagatur  ad3bf1fa94  fmt  2023-12-26 17:27:42 -05:00
Bagatur  299851bcac  fmt  2023-12-26 17:25:06 -05:00
Bagatur  4747cfafee  fmt  2023-12-26 17:21:35 -05:00
Bagatur  ba1b41c54d  fmt  2023-12-26 17:10:43 -05:00
Bagatur  7b2886a0a2  fmt  2023-12-26 16:32:55 -05:00
Bagatur  6859509a17  fmt  2023-12-26 16:15:11 -05:00
Bagatur  d9bbe2f144  fmt  2023-12-26 15:54:40 -05:00
Bagatur  5d1ca93cbf  fmt  2023-12-26 15:34:11 -05:00
Bagatur  ce1a1b5892  merge  2023-12-26 15:30:09 -05:00
Bagatur  93c17e0783  fmt  2023-12-26 13:57:42 -05:00
Bagatur  7ad09ce887  refactor  2023-12-26 13:47:52 -05:00
Bagatur  35b0a48635  fmt  2023-12-26 13:11:56 -05:00
Bagatur  01739a7b23  fmt  2023-12-26 12:08:48 -05:00
Bagatur  6f3c0ee781  rerank  2023-12-26 12:06:08 -05:00
Bagatur  fb43f69997  fmt  2023-12-26 10:54:18 -05:00
Bagatur  f8cae684b4  fix  2023-12-26 10:49:06 -05:00
Bagatur  107a02bf51  collapse  2023-12-26 10:16:46 -05:00
Bagatur  011908a442  fmt  2023-12-22 18:17:22 -05:00
Bagatur  d418259865  wip  2023-12-22 16:23:05 -05:00
Bagatur  a732063a77  fmt  2023-12-22 15:30:00 -05:00
Bagatur  492506cf82  fmt  2023-12-22 14:45:06 -05:00
Bagatur  00b0f6586e  refine  2023-12-22 14:30:25 -05:00
Bagatur  b50c9e0ae9  fmt  2023-12-22 13:43:19 -05:00
Bagatur  87ff61acde  fmt  2023-12-22 13:42:19 -05:00
Bagatur  a515261753  fmt  2023-12-22 13:41:17 -05:00
Bagatur  43471f36c4  rfc  2023-12-22 13:24:23 -05:00
8 changed files with 794 additions and 45 deletions

View File

@@ -799,7 +799,7 @@ class Runnable(Generic[Input, Output], ABC):
def with_config(
self,
config: Optional[RunnableConfig] = None,
# Sadly Unpack is not well supported by mypy so this will have to be untyped
# Sadly Unpack is not well-supported by mypy so this will have to be untyped
**kwargs: Any,
) -> Runnable[Input, Output]:
"""
@@ -814,6 +814,12 @@ class Runnable(Generic[Input, Output], ABC):
kwargs={},
)
def with_name(self, run_name: str) -> Runnable[Input, Output]:
"""
Bind run_name to a Runnable, returning a new Runnable.
"""
return self.with_config(run_name=run_name)
def with_listeners(
self,
*,
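
A minimal sketch of how the new `with_name` shorthand behaves, assuming this branch is installed; the toy `RunnableLambda` is purely illustrative:

from langchain_core.runnables import RunnableLambda

# Per the diff above, with_name(...) is shorthand for with_config(run_name=...).
double = RunnableLambda(lambda x: x * 2).with_name("double")
assert double.invoke(3) == 6  # traced under the run name "double"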

View File

@@ -1,10 +1,19 @@
"""Different ways to combine documents."""
from langchain.chains.combine_documents.map_reduce import (
create_map_documents_chain,
create_map_reduce_documents_chain,
)
from langchain.chains.combine_documents.map_rerank import (
create_map_rerank_documents_chain,
)
from langchain.chains.combine_documents.reduce import (
acollapse_docs,
collapse_docs,
create_collapse_documents_chain,
split_list_of_docs,
)
from langchain.chains.combine_documents.refine import create_refine_documents_chain
from langchain.chains.combine_documents.stuff import create_stuff_documents_chain
__all__ = [
@@ -12,4 +21,9 @@ __all__ = [
"collapse_docs",
"split_list_of_docs",
"create_stuff_documents_chain",
"create_map_documents_chain",
"create_map_rerank_documents_chain",
"create_map_reduce_documents_chain",
"create_refine_documents_chain",
"create_collapse_documents_chain",
]

View File

@@ -1,10 +1,11 @@
"""Base interface for chains combining documents."""
from __future__ import annotations
from abc import ABC, abstractmethod
from typing import Any, Dict, List, Optional, Tuple, Type
from typing import Any, Dict, List, Optional, Sequence, Tuple, Type
from langchain_core.documents import Document
from langchain_core.prompts import BasePromptTemplate, PromptTemplate
from langchain_core.prompts import BasePromptTemplate, PromptTemplate, format_document
from langchain_core.pydantic_v1 import BaseModel, Field, create_model
from langchain_core.runnables.config import RunnableConfig
@@ -15,19 +16,62 @@ from langchain.callbacks.manager import (
from langchain.chains.base import Chain
from langchain.text_splitter import RecursiveCharacterTextSplitter, TextSplitter
DEFAULT_DOCUMENT_SEPARATOR = "\n\n"
# --- Constants --- #
DOCUMENTS_KEY = "context"
INTERMEDIATE_STEPS_KEY = "intermediate_steps"
# --- Defaults --- #
DEFAULT_DOCUMENT_PROMPT = PromptTemplate.from_template("{page_content}")
DEFAULT_DOCUMENT_SEPARATOR = "\n\n"
def _validate_prompt(prompt: BasePromptTemplate) -> None:
if DOCUMENTS_KEY not in prompt.input_variables:
# --- Utilities (private) --- #
def validate_prompt(prompt: BasePromptTemplate, expected_inputs: Sequence[str]) -> None:
missing_keys = set(expected_inputs).difference(prompt.input_variables)
if missing_keys:
raise ValueError(
f"Prompt must accept {DOCUMENTS_KEY} as an input variable. Received prompt "
f"with input variables: {prompt.input_variables}"
f"Prompt must accept {expected_inputs} as an input variables. Received "
f"prompt with input variables: {prompt.input_variables}"
)
def format_documents(
documents: Sequence[Document],
document_prompt: BasePromptTemplate,
document_separator: str,
) -> str:
return document_separator.join(
format_document(doc, document_prompt) for doc in documents
)
def format_document_inputs(
inputs: Dict[str, Any],
document_prompt: BasePromptTemplate,
*,
document_separator: str = "\n\n",
) -> Dict[str, Any]:
docs_val = inputs[DOCUMENTS_KEY]
docs = docs_val if isinstance(docs_val, list) else [docs_val]
inputs[DOCUMENTS_KEY] = format_documents(docs, document_prompt, document_separator)
return inputs
def format_document_inputs_as_list(
inputs: Dict[str, Any],
document_prompt: BasePromptTemplate,
) -> List[Dict[str, Any]]:
docs = inputs.pop(DOCUMENTS_KEY)
return [
{DOCUMENTS_KEY: format_document(doc, document_prompt), **inputs} for doc in docs
]
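
A hypothetical usage sketch of the private formatting helpers above (assuming this branch's `langchain.chains.combine_documents.base` module); the `dict(inputs)` copies guard against the in-place mutation these helpers perform:

from langchain_core.documents import Document
from langchain_core.prompts import PromptTemplate

from langchain.chains.combine_documents.base import (
    format_document_inputs,
    format_document_inputs_as_list,
)

doc_prompt = PromptTemplate.from_template("{page_content}")
inputs = {
    "context": [Document(page_content="a"), Document(page_content="b")],
    "question": "q",
}
# Joins all docs into one string under "context":
format_document_inputs(dict(inputs), doc_prompt)
# -> {"context": "a\n\nb", "question": "q"}
# Fans out into one dict per document:
format_document_inputs_as_list(dict(inputs), doc_prompt)
# -> [{"context": "a", "question": "q"}, {"context": "b", "question": "q"}]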
# --- Legacy Chain --- #
class BaseCombineDocumentsChain(Chain, ABC):
"""Base interface for chains combining documents.

View File

@@ -1,18 +1,207 @@
"""Combining documents by mapping a chain over them first, then combining results."""
from __future__ import annotations
from typing import Any, Dict, List, Optional, Tuple, Type
from typing import Any, Callable, Dict, List, Optional, Tuple, Type, cast
from langchain_core.documents import Document
from langchain_core.language_models import LanguageModelLike
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import BasePromptTemplate
from langchain_core.pydantic_v1 import BaseModel, Extra, create_model, root_validator
from langchain_core.runnables import Runnable, RunnableLambda, RunnablePassthrough
from langchain_core.runnables.config import RunnableConfig
from langchain.callbacks.manager import Callbacks
from langchain.chains.combine_documents.base import BaseCombineDocumentsChain
from langchain.chains.combine_documents.base import (
DEFAULT_DOCUMENT_PROMPT,
DOCUMENTS_KEY,
BaseCombineDocumentsChain,
format_document_inputs,
validate_prompt,
)
from langchain.chains.combine_documents.reduce import ReduceDocumentsChain
from langchain.chains.llm import LLMChain
# --- LCEL Runnable chains --- #
def create_map_documents_chain(
llm: LanguageModelLike,
prompt: BasePromptTemplate,
*,
document_prompt: Optional[BasePromptTemplate] = None,
) -> Runnable[Dict[str, Any], List[Document]]:
"""Create a chain that passes each Document to an LLM and creates a new list of Documents.
Args:
llm: Language model to use for mapping document contents.
prompt: Prompt to use for mapping document contents. Must accept "context" as
one of the input variables. Each document will be passed in as "context".
document_prompt: Prompt used for formatting each document into a string. Input
variables can be "page_content" or any metadata keys that are in all
documents. "page_content" will automatically retrieve the
`Document.page_content`, and all other input variables will be
automatically retrieved from the `Document.metadata` dictionary. Defaults to
a prompt that only contains `Document.page_content`.
Returns:
An LCEL `Runnable` chain.
Expects a dictionary as input. The input must contain a "context" key holding a
list of Documents.
Returns a list of Documents, with the contents of each document being the output
of passing the corresponding input document to the model. Document order is
preserved.
Example:
.. code-block:: python
# pip install -U langchain langchain-community
from langchain_community.chat_models import ChatOpenAI
from langchain_core.documents import Document
from langchain_core.prompts import ChatPromptTemplate
from langchain.chains.combine_documents import create_map_documents_chain
llm = ChatOpenAI(model="gpt-3.5-turbo")
extract_prompt = ChatPromptTemplate.from_messages(
[
("system", "Given a user question, extract the most relevant parts of the following context:\n\n{context}"),
("human", "{question}"),
]
)
map_documents_chain = create_map_documents_chain(llm, extract_prompt)
docs = [
Document(page_content="Jesse loves red but not yellow"),
Document(page_content = "Jamal loves green but not as much as he loves orange")
]
map_documents_chain.invoke({"context": docs, "question": "Who loves green?"})
""" # noqa: E501
validate_prompt(prompt, (DOCUMENTS_KEY,))
# Runnable: Dict with single doc -> updated page content.
map_content = (
RunnableLambda(cast(Callable, format_document_inputs))
.bind(document_prompt=document_prompt or DEFAULT_DOCUMENT_PROMPT)
.pipe(prompt, llm, StrOutputParser(), name="map_content")
)
# Runnable: Dict with single doc -> updated doc.
map_doc = (
RunnablePassthrough.assign(page_content=map_content)
.with_name("assign_page_content")
.pipe(_compile_document, name="map_document")
)
# Runnable: Dict with many docs -> many dicts each with one doc.
format_as_list = RunnableLambda(_format_input_as_list)
# Runnable: Dict with many docs -> updated docs.
return format_as_list.pipe(map_doc.map(), name="map_documents_chain")
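
The `.map()` call above is what fans the per-document runnable out over the list; a minimal sketch of its semantics with a toy runnable (order is preserved):

from langchain_core.runnables import RunnableLambda

upper = RunnableLambda(lambda s: s.upper())
upper.map().invoke(["red", "green"])  # -> ['RED', 'GREEN']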
def create_map_reduce_documents_chain(
map_documents_chain: Runnable[Dict[str, Any], List[Document]],
reduce_documents_chain: Runnable[Dict[str, Any], Any],
*,
collapse_documents_chain: Optional[Runnable[Dict[str, Any], List[Document]]] = None,
) -> Runnable[Dict[str, Any], Any]:
"""Create a chain that first maps the contents of each document then reduces them.
Args:
map_documents_chain: Runnable chain for applying some function to the
contents of each document. Should accept dictionary as input and output a
list of Documents.
reduce_documents_chain: Runnable chain for reducing a list of Documents to a
single output. Should accept dictionary as input and is expected to read
the list of Documents from the "context" key.
collapse_documents_chain: Optional Runnable chain for consolidating a list of
Documents until the cumulative token size of all Documents is below some
token limit. Should accept dictionary as input and is expected to read the
list of Documents from the "context" key. If None, collapse step will not
be included in final chain. Else will be run after the map_documents_chain
and before the reduce_documents_chain.
Returns:
An LCEL `Runnable` chain.
Expects a dictionary as input with a list of `Document`s being passed under
the "context" key.
Return type matches the reduce_documents_chain return type.
Example:
.. code-block:: python
from langchain_community.chat_models import ChatOpenAI
from langchain_core.documents import Document
from langchain_core.prompts import ChatPromptTemplate
from langchain.chains.combine_documents import (
create_collapse_documents_chain,
create_map_documents_chain,
create_map_reduce_documents_chain,
create_stuff_documents_chain,
)
llm = ChatOpenAI(model="gpt-3.5-turbo")
extract_prompt = ChatPromptTemplate.from_messages(
[
("system", "Given a user question, extract the most relevant parts of the following context:\n\n{context}"),
("human", "{question}"),
]
)
map_documents_chain = create_map_documents_chain(llm, extract_prompt)
collapse_documents_chain = create_collapse_documents_chain(llm, extract_prompt, token_max=4000)
answer_prompt = ChatPromptTemplate.from_template(
[
("system", "Answer the user question using the following context:\n\n{context}"),
("human", "{question}"),
]
)
reduce_documents_chain = create_stuff_documents_chain(llm, answer_prompt)
map_reduce_documents_chain = create_map_reduce_documents_chain(
map_documents_chain,
reduce_documents_chain,
collapse_documents_chain=collapse_documents_chain
)
docs = [
Document(page_content="Jesse loves red but not yellow"),
Document(page_content = "Jamal loves green but not as much as he loves orange")
]
map_reduce_documents_chain.invoke({"context": docs, "question": "Who loves green?"})
""" # noqa: E501
assign_mapped_docs = RunnablePassthrough.assign(context=map_documents_chain)
if not collapse_documents_chain:
return assign_mapped_docs.pipe(
reduce_documents_chain, name="map_reduce_documents_chain"
)
else:
return assign_mapped_docs.assign(context=collapse_documents_chain).pipe(
reduce_documents_chain, name="map_reduce_documents_chain"
)
# --- Helper methods for LCEL Runnable chains --- #
def _compile_document(inputs: Dict[str, Any]) -> Document:
doc = inputs[DOCUMENTS_KEY]
return Document(page_content=inputs["page_content"], metadata=doc.metadata)
def _format_input_as_list(inputs: Dict[str, Any]) -> List[dict]:
docs = inputs.pop(DOCUMENTS_KEY)
return [{DOCUMENTS_KEY: doc, **inputs} for doc in docs]
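
A small sketch of the `RunnablePassthrough.assign` wiring used in `create_map_reduce_documents_chain`: it recomputes one key from the whole input dict while passing every other key through untouched (toy values, for illustration only):

from langchain_core.runnables import RunnablePassthrough

chain = RunnablePassthrough.assign(context=lambda d: d["context"] * 2)
chain.invoke({"context": 3, "question": "q"})
# -> {"context": 6, "question": "q"}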
# --- Legacy Chain --- #
class MapReduceDocumentsChain(BaseCombineDocumentsChain):
"""Combining documents by mapping a chain over them, then combining results.

View File

@@ -2,17 +2,186 @@
from __future__ import annotations
from typing import Any, Dict, List, Optional, Sequence, Tuple, Type, Union, cast
import re
from typing import (
Any,
Callable,
Dict,
List,
Optional,
Sequence,
Tuple,
Type,
Union,
cast,
)
from langchain_core.documents import Document
from langchain_core.pydantic_v1 import BaseModel, Extra, create_model, root_validator
from langchain_core.exceptions import OutputParserException
from langchain_core.language_models import LanguageModelLike
from langchain_core.output_parsers import BaseOutputParser, StrOutputParser
from langchain_core.prompts import BasePromptTemplate
from langchain_core.pydantic_v1 import (
BaseModel,
Extra,
Field,
create_model,
root_validator,
)
from langchain_core.runnables import (
Runnable,
RunnableLambda,
RunnableParallel,
)
from langchain_core.runnables.config import RunnableConfig
from langchain.callbacks.manager import Callbacks
from langchain.chains.combine_documents.base import BaseCombineDocumentsChain
from langchain.chains.combine_documents.base import (
DEFAULT_DOCUMENT_PROMPT,
DOCUMENTS_KEY,
BaseCombineDocumentsChain,
format_document_inputs_as_list,
validate_prompt,
)
from langchain.chains.llm import LLMChain
from langchain.output_parsers.regex import RegexParser
# --- LCEL Runnable chains --- #
def create_map_rerank_documents_chain(
llm: LanguageModelLike,
prompt: BasePromptTemplate,
*,
output_parser: Optional[BaseOutputParser[Dict[str, Any]]] = None,
document_prompt: Optional[BasePromptTemplate] = None,
) -> Runnable[Dict[str, Any], _MapRerankOutput]:
"""Create a chain that writes candidates answers for each doc and selects the best one.
Args:
llm: Language model to use for responding.
prompt: Prompt to use for answering and scoring. Must accept "context"
as one of the input variables. Should work well with the given
output_parser so that a score and answer can be extracted from each
model output. Scores must be comparable (e.g. floats or ints) and a
**higher score** should indicate a more relevant answer.
output_parser: Output parser to use. Must return a dictionary containing a
"score" and "answer" key. If none is provided, will default to a simple
regex parser that searches for "Answer:" and "Score:" keywords in the
model output.
document_prompt: Prompt used for formatting each document into a string. Input
variables can be "page_content" or any metadata keys that are in all
documents. "page_content" will automatically retrieve the
`Document.page_content`, and all other input variables will be
automatically retrieved from the `Document.metadata` dictionary. Defaults to
a prompt that only contains `Document.page_content`.
Returns:
An LCEL `Runnable` chain.
Expects a dictionary as input with a list of `Document`s being passed under
the "context" key.
Returns a dictionary as output that looks like:
{
"top_answer": "top-score answer string",
"all_answers": [{"score": ..., "answer": "..."}, {"score": ..., "answer": "..."}, ...]
}
The scored answers in "all_answers" preserve the order of the input Documents.
Example:
.. code-block:: python
# pip install -U langchain langchain-community
from langchain_community.chat_models import ChatOpenAI
from langchain_core.documents import Document
from langchain_core.output_parsers import SimpleJsonOutputParser
from langchain_core.prompts import ChatPromptTemplate
from langchain.chains.combine_documents import create_map_rerank_documents_chain
llm = ChatOpenAI(model="gpt-3.5-turbo-1106").bind(
response_format={"type": "json_object"}
)
prompt = ChatPromptTemplate.from_messages(
[
("system", '''Answer the user question using the given context. Score the
answer from 1-5 based on how relevant the context is to the question. Return a JSON string with
keys 'answer' and 'score'.\n\n{context}'''),
("human", "{question}"),
]
)
chain = create_map_rerank_documents_chain(llm, prompt, output_parser=SimpleJsonOutputParser())
docs = [
Document(page_content="Jesse loves red but not yellow"),
Document(page_content = "Jamal loves green but not as much as he loves orange")
]
chain.invoke({"context": docs, "question": "Who loves green?"})
# -> {
# 'all_answers':
# [
# {'answer': 'This context does not provide any information about who loves green.', 'score': 1},
# {'answer': 'Jamal loves green but not as much as he loves orange', 'score': 5}
# ],
# 'top_answer': 'Jamal loves green but not as much as he loves orange'
# }
""" # noqa: E501
validate_prompt(prompt, (DOCUMENTS_KEY,))
_output_parser = output_parser or (StrOutputParser() | _default_regex)
# Runnable: Dict with single doc -> {"answer": ..., "score": ...}
answer_chain = prompt.pipe(llm, _output_parser, name="answer_and_score")
# Runnable: Dict with many docs -> [{"answer": ..., "score": ...}, ...]
map_chain = (
RunnableLambda(cast(Callable, format_document_inputs_as_list))
.bind(document_prompt=document_prompt or DEFAULT_DOCUMENT_PROMPT)
.pipe(answer_chain.map(), name="answer_and_score_all")
)
# Runnable: Dict with many docs -> {"top_answer": "...", "all_answers": [...]}
return (
RunnableParallel(all_answers=map_chain)
.assign(top_answer=_top_answer)
.with_name("map_rerank_documents_chain")
.with_types(output_type=_MapRerankOutput)
)
### --- Helper methods for LCEL Runnable chain --- ###
def _top_answer(results: Dict[str, Any]) -> str:
return max(results["all_answers"], key=lambda x: float(x["score"]))["answer"]
def _default_regex(text: str) -> dict:
regex = r"(?:Answer:\s)?(.*?)\s*Score:\s(.*)"
match = re.search(regex, text, re.IGNORECASE | re.DOTALL)
if match is None:
raise OutputParserException(
ValueError(
f"Model output did not match expected regex. Expected {regex}, "
f"received: {text}."
)
)
else:
return {"answer": match.group(1), "score": match.group(2)}
class _MapRerankOutput(BaseModel):
top_answer: str = Field(..., description="The highest-scored answer.")
all_answers: List[Dict[str, Any]] = Field(
...,
description="All answers and scores, in the same order as the input documents.",
)
# --- Legacy Chain --- #
class MapRerankDocumentsChain(BaseCombineDocumentsChain):
"""Combining documents by mapping a chain over them, then reranking results.

View File

@@ -2,13 +2,204 @@
from __future__ import annotations
from typing import Any, Callable, List, Optional, Protocol, Tuple
from typing import (
Any,
Callable,
Dict,
List,
Optional,
Protocol,
Sequence,
Tuple,
Union,
cast,
)
from langchain_core.documents import Document
from langchain_core.language_models import LanguageModelLike
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import BasePromptTemplate
from langchain_core.pydantic_v1 import Extra
from langchain_core.runnables import (
Runnable,
RunnableLambda,
RunnableParallel,
RunnablePassthrough,
)
from langchain.callbacks.manager import Callbacks
from langchain.chains.combine_documents.base import BaseCombineDocumentsChain
from langchain.chains.combine_documents.base import (
DEFAULT_DOCUMENT_PROMPT,
DEFAULT_DOCUMENT_SEPARATOR,
DOCUMENTS_KEY,
BaseCombineDocumentsChain,
format_document_inputs,
format_documents,
validate_prompt,
)
from langchain.pydantic_v1 import BaseModel
# --- LCEL Runnable chains --- #
def create_collapse_documents_chain(
llm: LanguageModelLike,
prompt: BasePromptTemplate,
*,
token_max: int,
document_prompt: Optional[BasePromptTemplate] = None,
document_separator: str = DEFAULT_DOCUMENT_SEPARATOR,
token_len_func: Optional[Callable[[str], int]] = None,
) -> Runnable[Dict[str, Any], List[Document]]:
"""Create
Args:
llm: Language model to use for collapsing document contents.
prompt: Prompt to use for collapsing document contents. Must accept "context" as
one of the input variables. The formatted documents will be passed in as
"context".
token_max: The maximum cumulative token length that the list of Documents can
have.
document_prompt: Prompt used for formatting each document into a string. Input
variables can be "page_content" or any metadata keys that are in all
documents. "page_content" will automatically retrieve the
`Document.page_content`, and all other input variables will be
automatically retrieved from the `Document.metadata` dictionary. Defaults to
a prompt that only contains `Document.page_content`.
document_separator: Separator string to use for joining formatted document
strings.
token_len_func: Optional Callable for computing the token length of a string.
Should take a string as input and return an int. If None, defaults to
`llm.get_num_tokens` if the `llm` has that method, otherwise falls back to `len`.
Returns:
An LCEL `Runnable` chain.
Expects a dictionary as input. The input must contain a "context" key holding a
list of Documents.
Returns a list of Documents whose cumulative token length when formatted as
strings is below `token_max`.
Example:
.. code-block:: python
# pip install -U langchain langchain-community
from langchain_community.chat_models import ChatOpenAI
from langchain_core.documents import Document
from langchain_core.prompts import ChatPromptTemplate
from langchain.chains.combine_documents import create_collapse_documents_chain
llm = ChatOpenAI(model="gpt-3.5-turbo")
extract_prompt = ChatPromptTemplate.from_messages(
[
("system", "Given a user question, extract the most relevant parts of the following context:\n\n{context}"),
("human", "{question}"),
]
)
collapse_documents_chain = create_collapse_documents_chain(llm, extract_prompt, token_max=1000)
docs = [
Document(page_content="I love yellow. " * 200),
Document(page_content="You love green. " * 200),
]
collapse_documents_chain.invoke({"context": docs, "question": "What color do you love?"})
""" # noqa: E501
validate_prompt(prompt, (DOCUMENTS_KEY,))
_document_prompt = document_prompt or DEFAULT_DOCUMENT_PROMPT
_token_len_func: Callable = token_len_func or getattr(llm, "get_num_tokens", len) # type: ignore # noqa: E501
# Runnable: Dict with many docs -> reduced string.
reduce_content = (
RunnableLambda(cast(Callable, format_document_inputs))
.bind(
document_prompt=_document_prompt,
document_separator=document_separator,
)
.pipe(prompt, llm, StrOutputParser(), name="reduce_content")
)
# Runnable: Dict with many docs -> single Document.
reduce_docs = RunnableParallel(
page_content=reduce_content,
metadata=_combine_metadata,
).pipe(_dict_to_document, name="reduce_documents")
def partition_docs(inputs: dict) -> List[dict]:
docs = inputs.pop(DOCUMENTS_KEY)
partitions = split_list_of_docs(
docs,
_docs_len,
token_max,
token_len_func=_token_len_func,
document_prompt=_document_prompt,
document_separator=document_separator,
)
return [{DOCUMENTS_KEY: docs, **inputs} for docs in partitions]
partition_and_reduce = RunnableLambda(partition_docs).pipe(
reduce_docs.map(), name="partition_and_reduce"
)
def collapse_loop(inputs: dict) -> Union[List[Document], Runnable]:
"""Recursive collapse loop.
If the cumulative token length of the documents exceeds the token_max,
append another partition and reduce step to the Runnable sequence. Otherwise
return the docs.
"""
docs = inputs[DOCUMENTS_KEY]
curr_len = _docs_len(
docs, _token_len_func, _document_prompt, document_separator
)
if curr_len > token_max:
return RunnablePassthrough.assign(context=partition_and_reduce).pipe(
collapse_loop, name="collapse_step"
)
else:
return docs
return (
RunnableLambda(cast(Callable, collapse_loop))
.with_name("collapse_documents_chain")
.with_types(output_type=_CollapseOutputType) # type: ignore
)
# --- Helper methods for LCEL Runnable chains --- #
def _combine_metadata(inputs: Dict[str, Any]) -> Dict[Any, str]:
docs = inputs[DOCUMENTS_KEY]
combined_metadata = {k: str(v) for k, v in docs[0].metadata.items()}
for doc in docs[1:]:
for k, v in doc.metadata.items():
if k in combined_metadata:
combined_metadata[k] += f", {v}"
else:
combined_metadata[k] = str(v)
return combined_metadata
def _dict_to_document(inputs: Dict[str, Any]) -> Document:
return Document(page_content=inputs["page_content"], metadata=inputs["metadata"])
def _docs_len(
docs: Sequence[Document],
token_len_func: Callable[[str], int],
document_prompt: BasePromptTemplate,
document_separator: str,
) -> int:
return token_len_func(format_documents(docs, document_prompt, document_separator))
class _CollapseOutputType(BaseModel):
__root__: List[Document]
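
The partitioning step relies on the pre-existing `split_list_of_docs` helper from this module; a minimal sketch of its behavior, using character length as a stand-in token counter:

from langchain_core.documents import Document

from langchain.chains.combine_documents.reduce import split_list_of_docs

docs = [Document(page_content="x" * 40) for _ in range(5)]
parts = split_list_of_docs(
    docs, lambda ds: sum(len(d.page_content) for d in ds), 100
)
[len(p) for p in parts]  # -> [2, 2, 1]; no partition exceeds the limit of 100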
# --- Legacy Chain --- #
class CombineDocsProtocol(Protocol):

View File

@@ -2,22 +2,157 @@
from __future__ import annotations
from typing import Any, Dict, List, Tuple
from typing import Any, Callable, Dict, List, Optional, Tuple, cast
from langchain_core.documents import Document
from langchain_core.language_models import LanguageModelLike
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import BasePromptTemplate, format_document
from langchain_core.prompts.prompt import PromptTemplate
from langchain_core.pydantic_v1 import Extra, Field, root_validator
from langchain_core.runnables import (
Runnable,
RunnableLambda,
RunnablePassthrough,
)
from langchain.callbacks.manager import Callbacks
from langchain.chains.combine_documents.base import (
DEFAULT_DOCUMENT_PROMPT,
DOCUMENTS_KEY,
INTERMEDIATE_STEPS_KEY,
BaseCombineDocumentsChain,
validate_prompt,
)
from langchain.chains.llm import LLMChain
# --- LCEL Runnable chain --- #
def _get_default_document_prompt() -> PromptTemplate:
return PromptTemplate(input_variables=["page_content"], template="{page_content}")
OUTPUT_KEY = "output"
def create_refine_documents_chain(
llm: LanguageModelLike,
initial_prompt: BasePromptTemplate,
refine_prompt: BasePromptTemplate,
*,
document_prompt: Optional[BasePromptTemplate] = None,
) -> Runnable:
"""Create a chain that feeds documents to a model one at a time and updates the output.
Args:
llm: Language model to use for responding.
initial_prompt: The prompt to use on the first document. Must accept "context"
as one of the input variables. The first document will be passed in as
"context".
refine_prompt: The prompt to use on all subsequent documents. Must accept
"context" and "output" as input variables. A document will be passed in as
"context" and the refined output up to this iteration will be passed in as
"output.
document_prompt: Prompt used for formatting each document into a string. Input
variables can be "page_content" or any metadata keys that are in all
documents. "page_content" will automatically retrieve the
`Document.page_content`, and all other input variables will be
automatically retrieved from the `Document.metadata` dictionary. Defaults to
a prompt that only contains `Document.page_content`.
Returns:
An LCEL `Runnable` chain.
Expects a dictionary as input with a list of `Document`s being passed under
the "context" key.
Returns a dictionary as output. The output dictionary contains two keys,
"output" and "intermediate_steps". "output" contains the final output.
"intermediate_steps" contains the list of intermediate output
strings generated by the chain, in the order that they were generated.
Example:
.. code-block:: python
# pip install -U langchain langchain-community
from langchain_community.chat_models import ChatOpenAI
from langchain_core.documents import Document
from langchain_core.prompts import ChatPromptTemplate
from langchain.chains.combine_documents.refine import create_refine_documents_chain
initial_prompt = ChatPromptTemplate.from_messages([
("system", "Summarize this information: {context}"),
])
refine_prompt = ChatPromptTemplate.from_messages([
("system", '''You are summarizing a long document one page at a time. \
You have summarized part of the document. Given the next page, update your \
summary. Respond with only the updated summary and no other text. \
Here is your working summary:\n\n{output}.'''),
("human", "Here is the next page:\n\n{context}")
])
llm = ChatOpenAI(model="gpt-3.5-turbo")
chain = create_refine_documents_chain(llm, initial_prompt, refine_prompt, llm,)
# chain.invoke({"context": docs})
""" # noqa: E501
validate_prompt(initial_prompt, (DOCUMENTS_KEY,))
validate_prompt(refine_prompt, (DOCUMENTS_KEY, OUTPUT_KEY))
format_doc: Runnable = RunnableLambda(cast(Callable, _get_and_format_doc)).bind(
document_prompt=document_prompt or DEFAULT_DOCUMENT_PROMPT
)
# Runnable: Dict with many docs -> answer string based on first doc
initial_response = format_doc.pipe(
initial_prompt, llm, StrOutputParser(), name="initial_response"
)
# Runnable: Dict with many docs, current answer, intermediate_steps
# -> updated answer based on next doc
refine_response = format_doc.pipe(
refine_prompt, llm, StrOutputParser(), name="refine_response"
)
# Runnable: Update intermediate_steps based on the last output; in parallel,
# update output.
refine_step = RunnablePassthrough.assign(
intermediate_steps=_update_intermediate_steps,
output=refine_response,
)
# Function that returns a sequence of refine_steps equal to len(docs) - 1.
refine_loop = RunnableLambda(cast(Callable, _runnable_loop)).bind(
step=refine_step, step_name="refine_step_{iteration}"
)
# Runnable: Dict with many docs -> {"answer": "...", "intermediate_steps": [...]}
return (
RunnablePassthrough.assign(output=initial_response)
.pipe(refine_loop)
.pick([OUTPUT_KEY, INTERMEDIATE_STEPS_KEY])
.with_name("refine_documents_chain")
)
# --- Helpers for LCEL Runnable chain --- #
def _get_and_format_doc(inputs: dict, document_prompt: BasePromptTemplate) -> dict:
intermediate_steps = inputs.pop(INTERMEDIATE_STEPS_KEY, [])
doc = inputs[DOCUMENTS_KEY][len(intermediate_steps)]
inputs[DOCUMENTS_KEY] = format_document(doc, document_prompt)
return inputs
def _runnable_loop(inputs: dict, step: Runnable, step_name: str) -> Runnable:
if len(inputs[DOCUMENTS_KEY]) < 2:
return RunnablePassthrough()
chain: Runnable = step.with_name(step_name.format(iteration=1))
for iteration in range(2, len(inputs[DOCUMENTS_KEY])):
chain |= step.with_name(step_name.format(iteration=iteration))
return chain
def _update_intermediate_steps(inputs: dict) -> list:
return inputs.get(INTERMEDIATE_STEPS_KEY, []) + [inputs[OUTPUT_KEY]]
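
A toy sketch of the sequential composition `_runnable_loop` builds: for n documents it chains n - 1 copies of the refine step, each consuming one more document:

from langchain_core.runnables import RunnableLambda

step = RunnableLambda(lambda x: x + 1)
loop = step | step | step  # what _runnable_loop yields for four documents
loop.invoke(0)  # -> 3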
# --- Legacy Chain --- #
class RefineDocumentsChain(BaseCombineDocumentsChain):
@@ -81,7 +216,7 @@ class RefineDocumentsChain(BaseCombineDocumentsChain):
initial_response_name: str
"""The variable name to format the initial response in when refining."""
document_prompt: BasePromptTemplate = Field(
default_factory=_get_default_document_prompt
default_factory=lambda: DEFAULT_DOCUMENT_PROMPT
)
"""Prompt to use to format each document, gets passed to `format_document`."""
return_intermediate_steps: bool = False

View File

@@ -1,12 +1,12 @@
"""Chain that combines documents by stuffing into context."""
from typing import Any, Dict, List, Optional, Tuple
from typing import Any, Callable, Dict, List, Optional, Tuple, cast
from langchain_core.documents import Document
from langchain_core.language_models import LanguageModelLike
from langchain_core.output_parsers import BaseOutputParser, StrOutputParser
from langchain_core.prompts import BasePromptTemplate, format_document
from langchain_core.pydantic_v1 import Extra, Field, root_validator
from langchain_core.runnables import Runnable, RunnablePassthrough
from langchain_core.runnables import Runnable, RunnableLambda
from langchain.callbacks.manager import Callbacks
from langchain.chains.combine_documents.base import (
@@ -14,7 +14,8 @@ from langchain.chains.combine_documents.base import (
DEFAULT_DOCUMENT_SEPARATOR,
DOCUMENTS_KEY,
BaseCombineDocumentsChain,
_validate_prompt,
format_document_inputs,
validate_prompt,
)
from langchain.chains.llm import LLMChain
@@ -27,13 +28,13 @@ def create_stuff_documents_chain(
document_prompt: Optional[BasePromptTemplate] = None,
document_separator: str = DEFAULT_DOCUMENT_SEPARATOR,
) -> Runnable[Dict[str, Any], Any]:
"""Create a chain for passing a list of Documents to a model.
"""Create a chain that passes a list of documents to a model.
Args:
llm: Language model.
llm: Language model to use for responding.
prompt: Prompt template. Must contain input variable "context", which will be
used for passing in the formatted documents.
output_parser: Output parser. Defaults to StrOutputParser.
output_parser: Output parser. Defaults to `StrOutputParser`.
document_prompt: Prompt used for formatting each document into a string. Input
variables can be "page_content" or any metadata keys that are in all
documents. "page_content" will automatically retrieve the
@@ -43,9 +44,12 @@ def create_stuff_documents_chain(
document_separator: String separator to use between formatted document strings.
Returns:
An LCEL Runnable. The input is a dictionary that must have a "context" key that
maps to a List[Document], and any other input variables expected in the prompt.
The Runnable return type depends on output_parser used.
An LCEL `Runnable` chain.
Expects a dictionary as input with a list of `Document`s being passed under
the "context" key.
Return type depends on the `output_parser` used.
Example:
.. code-block:: python
@@ -58,7 +62,10 @@ def create_stuff_documents_chain(
from langchain.chains.combine_documents import create_stuff_documents_chain
prompt = ChatPromptTemplate.from_messages(
[("system", "What are everyone's favorite colors:\n\n{context}")]
[
("system", "Use the following context to answer the user's questions:\n\n{context}"),
("human", "{question}"),
]
)
llm = ChatOpenAI(model_name="gpt-3.5-turbo")
chain = create_stuff_documents_chain(llm, prompt)
@@ -68,26 +75,20 @@ def create_stuff_documents_chain(
Document(page_content = "Jamal loves green but not as much as he loves orange")
]
chain.invoke({"context": docs})
chain.invoke({"context": docs, "question": "What are everyone's favorite colors?"})
""" # noqa: E501
_validate_prompt(prompt)
_document_prompt = document_prompt or DEFAULT_DOCUMENT_PROMPT
validate_prompt(prompt, (DOCUMENTS_KEY,))
_output_parser = output_parser or StrOutputParser()
def format_docs(inputs: dict) -> str:
return document_separator.join(
format_document(doc, _document_prompt) for doc in inputs[DOCUMENTS_KEY]
)
return (
RunnablePassthrough.assign(**{DOCUMENTS_KEY: format_docs}).with_config(
run_name="format_inputs"
RunnableLambda(cast(Callable, format_document_inputs))
.bind(
document_prompt=document_prompt or DEFAULT_DOCUMENT_PROMPT,
document_separator=document_separator,
)
| prompt
| llm
| _output_parser
).with_config(run_name="stuff_documents_chain")
.pipe(prompt, llm, _output_parser, name="stuff_documents_chain")
)
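
As an aside, the `.pipe(...)` style used throughout is equivalent to chaining runnables with `|`, plus an optional trace name; a toy sketch (assuming langchain_core):

from langchain_core.runnables import RunnableLambda

add_one = RunnableLambda(lambda x: x + 1)
times_two = RunnableLambda(lambda x: x * 2)
assert add_one.pipe(times_two).invoke(3) == (add_one | times_two).invoke(3) == 8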
class StuffDocumentsChain(BaseCombineDocumentsChain):