Mirror of https://github.com/hwchase17/langchain.git (synced 2026-02-03 15:55:44 +00:00)

Compare commits: langchain=... → bagatur/st (51 commits)
Commits (abbreviated SHA1s):
a1ecf827c3, 5a40db5387, e46ab1c580, 80efbad918, 52b9e064ce, ffa2b5732f, 50c9d652f4, ded7fa3df7, 958f75df10, c063a87aec, 2ca8f94d07, f55adeaf8e, da0d18c184, 531c68dc5f, 7d6e8d6ed5, 78ff12812c, 66c3c4fb4b, e01a3458e9, 581172dc51, 7ba35b0ff4, 0725ed1ee5, 4edc79ad04, 11d14c8b95, d06d6b99e3, be4c76cd87, ad3bf1fa94, 299851bcac, 4747cfafee, ba1b41c54d, 7b2886a0a2, 6859509a17, d9bbe2f144, 5d1ca93cbf, ce1a1b5892, 93c17e0783, 7ad09ce887, 35b0a48635, 01739a7b23, 6f3c0ee781, fb43f69997, f8cae684b4, 107a02bf51, 011908a442, d418259865, a732063a77, 492506cf82, 00b0f6586e, b50c9e0ae9, 87ff61acde, a515261753, 43471f36c4
@@ -799,7 +799,7 @@ class Runnable(Generic[Input, Output], ABC):
     def with_config(
         self,
         config: Optional[RunnableConfig] = None,
-        # Sadly Unpack is not well supported by mypy so this will have to be untyped
+        # Sadly Unpack is not well-supported by mypy so this will have to be untyped
        **kwargs: Any,
    ) -> Runnable[Input, Output]:
        """
@@ -814,6 +814,12 @@ class Runnable(Generic[Input, Output], ABC):
             kwargs={},
         )

+    def with_name(self, run_name: str) -> Runnable[Input, Output]:
+        """
+        Bind run_name to a Runnable, returning a new Runnable.
+        """
+        return self.with_config(run_name=run_name)
+
     def with_listeners(
         self,
         *,
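To see what the new `with_name` helper does, here is a minimal sketch (only `langchain_core` is assumed; per the diff above, `with_name` is shorthand for `with_config(run_name=...)`):

from langchain_core.runnables import RunnableLambda

# with_name("double") binds a run name used in tracing and debug output.
double = RunnableLambda(lambda x: x * 2).with_name("double")
# Equivalent spelling via with_config:
same = RunnableLambda(lambda x: x * 2).with_config(run_name="double")

assert double.invoke(3) == same.invoke(3) == 6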
@@ -1,10 +1,19 @@
 """Different ways to combine documents."""

+from langchain.chains.combine_documents.map_reduce import (
+    create_map_documents_chain,
+    create_map_reduce_documents_chain,
+)
+from langchain.chains.combine_documents.map_rerank import (
+    create_map_rerank_documents_chain,
+)
 from langchain.chains.combine_documents.reduce import (
     acollapse_docs,
     collapse_docs,
+    create_collapse_documents_chain,
     split_list_of_docs,
 )
+from langchain.chains.combine_documents.refine import create_refine_documents_chain
 from langchain.chains.combine_documents.stuff import create_stuff_documents_chain

 __all__ = [
@@ -12,4 +21,9 @@ __all__ = [
     "collapse_docs",
     "split_list_of_docs",
     "create_stuff_documents_chain",
+    "create_map_documents_chain",
+    "create_map_rerank_documents_chain",
+    "create_map_reduce_documents_chain",
+    "create_refine_documents_chain",
+    "create_collapse_documents_chain",
 ]
@@ -1,10 +1,11 @@
 """Base interface for chains combining documents."""
 from __future__ import annotations

 from abc import ABC, abstractmethod
-from typing import Any, Dict, List, Optional, Tuple, Type
+from typing import Any, Dict, List, Optional, Sequence, Tuple, Type

 from langchain_core.documents import Document
-from langchain_core.prompts import BasePromptTemplate, PromptTemplate
+from langchain_core.prompts import BasePromptTemplate, PromptTemplate, format_document
 from langchain_core.pydantic_v1 import BaseModel, Field, create_model
 from langchain_core.runnables.config import RunnableConfig

@@ -15,19 +16,62 @@ from langchain.callbacks.manager import (
 from langchain.chains.base import Chain
 from langchain.text_splitter import RecursiveCharacterTextSplitter, TextSplitter

-DEFAULT_DOCUMENT_SEPARATOR = "\n\n"
+# --- Constants --- #
 DOCUMENTS_KEY = "context"
+INTERMEDIATE_STEPS_KEY = "intermediate_steps"
+
+# --- Defaults --- #
 DEFAULT_DOCUMENT_PROMPT = PromptTemplate.from_template("{page_content}")
+DEFAULT_DOCUMENT_SEPARATOR = "\n\n"
+
+
+# --- Utilities (private) --- #


-def _validate_prompt(prompt: BasePromptTemplate) -> None:
-    if DOCUMENTS_KEY not in prompt.input_variables:
+def validate_prompt(prompt: BasePromptTemplate, expected_inputs: Sequence[str]) -> None:
+    missing_keys = set(expected_inputs).difference(prompt.input_variables)
+    if missing_keys:
         raise ValueError(
-            f"Prompt must accept {DOCUMENTS_KEY} as an input variable. Received prompt "
-            f"with input variables: {prompt.input_variables}"
+            f"Prompt must accept {expected_inputs} as input variables. Received "
+            f"prompt with input variables: {prompt.input_variables}"
         )
+
+
+def format_documents(
+    documents: Sequence[Document],
+    document_prompt: BasePromptTemplate,
+    document_separator: str,
+) -> str:
+    return document_separator.join(
+        format_document(doc, document_prompt) for doc in documents
+    )
+
+
+def format_document_inputs(
+    inputs: Dict[str, Any],
+    document_prompt: BasePromptTemplate,
+    *,
+    document_separator: str = "\n\n",
+) -> Dict[str, Any]:
+    docs_val = inputs[DOCUMENTS_KEY]
+    docs = docs_val if isinstance(docs_val, list) else [docs_val]
+    # Return a new dict instead of mutating the input in place: the same input
+    # dict may be read by other branches of a RunnableParallel.
+    return {
+        **inputs,
+        DOCUMENTS_KEY: format_documents(docs, document_prompt, document_separator),
+    }
+
+
+def format_document_inputs_as_list(
+    inputs: Dict[str, Any],
+    document_prompt: BasePromptTemplate,
+) -> List[Dict[str, Any]]:
+    # Copy before popping so the caller's input dict is not mutated.
+    inputs = dict(inputs)
+    docs = inputs.pop(DOCUMENTS_KEY)
+    return [
+        {DOCUMENTS_KEY: format_document(doc, document_prompt), **inputs}
+        for doc in docs
+    ]
+
+
+# --- Legacy Chain --- #
+
+
 class BaseCombineDocumentsChain(Chain, ABC):
     """Base interface for chains combining documents.
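The new helpers are easiest to read through a small example. A hedged sketch (only `langchain_core` is assumed; `format_document` is the same function the diff imports):

from langchain_core.documents import Document
from langchain_core.prompts import PromptTemplate, format_document

document_prompt = PromptTemplate.from_template("{source}: {page_content}")
docs = [
    Document(page_content="red", metadata={"source": "a.txt"}),
    Document(page_content="green", metadata={"source": "b.txt"}),
]

# format_documents(docs, document_prompt, "\n\n") reduces to this join:
joined = "\n\n".join(format_document(d, document_prompt) for d in docs)
assert joined == "a.txt: red\n\nb.txt: green"

# format_document_inputs then swaps the Documents under "context" for that
# string, so a prompt expecting plain text can consume the dict directly:
inputs = {"context": docs, "question": "Which colors appear?"}
formatted = {**inputs, "context": joined}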
@@ -1,18 +1,207 @@
 """Combining documents by mapping a chain over them first, then combining results."""

 from __future__ import annotations

-from typing import Any, Dict, List, Optional, Tuple, Type
+from typing import Any, Callable, Dict, List, Optional, Tuple, Type, cast

 from langchain_core.documents import Document
+from langchain_core.language_models import LanguageModelLike
+from langchain_core.output_parsers import StrOutputParser
+from langchain_core.prompts import BasePromptTemplate
 from langchain_core.pydantic_v1 import BaseModel, Extra, create_model, root_validator
+from langchain_core.runnables import Runnable, RunnableLambda, RunnablePassthrough
 from langchain_core.runnables.config import RunnableConfig

 from langchain.callbacks.manager import Callbacks
-from langchain.chains.combine_documents.base import BaseCombineDocumentsChain
+from langchain.chains.combine_documents.base import (
+    DEFAULT_DOCUMENT_PROMPT,
+    DOCUMENTS_KEY,
+    BaseCombineDocumentsChain,
+    format_document_inputs,
+    validate_prompt,
+)
 from langchain.chains.combine_documents.reduce import ReduceDocumentsChain
 from langchain.chains.llm import LLMChain

+# --- LCEL Runnable chains --- #
+
+
+def create_map_documents_chain(
+    llm: LanguageModelLike,
+    prompt: BasePromptTemplate,
+    *,
+    document_prompt: Optional[BasePromptTemplate] = None,
+) -> Runnable[Dict[str, Any], List[Document]]:
+    """Create a chain that passes each Document to an LLM and creates a new list of Documents.
+
+    Args:
+        llm: Language model to use for mapping document contents.
+        prompt: Prompt to use for mapping document contents. Must accept "context" as
+            one of the input variables. Each document will be passed in as "context".
+        document_prompt: Prompt used for formatting each document into a string. Input
+            variables can be "page_content" or any metadata keys that are in all
+            documents. "page_content" will automatically retrieve the
+            `Document.page_content`, and all other input variables will be
+            automatically retrieved from the `Document.metadata` dictionary. Defaults
+            to a prompt that only contains `Document.page_content`.
+
+    Returns:
+        An LCEL `Runnable` chain.
+
+        Expects a dictionary as input. Input must contain a "context" key with a list
+        of Documents.
+
+        Returns a list of Documents, with the contents of each document being the
+        output of passing the corresponding input document to the model. Document
+        order is preserved.
+
+    Example:
+        .. code-block:: python
+
+            # pip install -U langchain langchain-community
+
+            from langchain_community.chat_models import ChatOpenAI
+            from langchain_core.documents import Document
+            from langchain_core.prompts import ChatPromptTemplate
+            from langchain.chains.combine_documents import create_map_documents_chain
+
+            llm = ChatOpenAI(model="gpt-3.5-turbo")
+            extract_prompt = ChatPromptTemplate.from_messages(
+                [
+                    ("system", "Given a user question, extract the most relevant parts of the following context:\n\n{context}"),
+                    ("human", "{question}"),
+                ]
+            )
+            map_documents_chain = create_map_documents_chain(llm, extract_prompt)
+
+            docs = [
+                Document(page_content="Jesse loves red but not yellow"),
+                Document(page_content="Jamal loves green but not as much as he loves orange"),
+            ]
+
+            map_documents_chain.invoke({"context": docs, "question": "Who loves green?"})
+    """  # noqa: E501
+    validate_prompt(prompt, (DOCUMENTS_KEY,))
+
+    # Runnable: Dict with single doc -> updated page content.
+    map_content = (
+        RunnableLambda(cast(Callable, format_document_inputs))
+        .bind(document_prompt=document_prompt or DEFAULT_DOCUMENT_PROMPT)
+        .pipe(prompt, llm, StrOutputParser(), name="map_content")
+    )
+
+    # Runnable: Dict with single doc -> updated doc.
+    map_doc = (
+        RunnablePassthrough.assign(page_content=map_content)
+        .with_name("assign_page_content")
+        .pipe(_compile_document, name="map_document")
+    )
+
+    # Runnable: Dict with many docs -> many dicts each with one doc.
+    format_as_list = RunnableLambda(_format_input_as_list)
+
+    # Runnable: Dict with many docs -> updated docs.
+    return format_as_list.pipe(map_doc.map(), name="map_documents_chain")
+
+
+def create_map_reduce_documents_chain(
+    map_documents_chain: Runnable[Dict[str, Any], List[Document]],
+    reduce_documents_chain: Runnable[Dict[str, Any], Any],
+    *,
+    collapse_documents_chain: Optional[Runnable[Dict[str, Any], List[Document]]] = None,
+) -> Runnable[Dict[str, Any], Any]:
+    """Create a chain that first maps the contents of each document then reduces them.
+
+    Args:
+        map_documents_chain: Runnable chain for applying some function to the contents
+            of each document. Should accept a dictionary as input and output a list of
+            Documents.
+        reduce_documents_chain: Runnable chain for reducing a list of Documents to a
+            single output. Should accept a dictionary as input and is expected to read
+            the list of Documents from the "context" key.
+        collapse_documents_chain: Optional Runnable chain for consolidating a list of
+            Documents until the cumulative token size of all Documents is below some
+            token limit. Should accept a dictionary as input and is expected to read
+            the list of Documents from the "context" key. If None, the collapse step
+            is not included in the final chain. Otherwise it is run after the
+            map_documents_chain and before the reduce_documents_chain.
+
+    Returns:
+        An LCEL `Runnable` chain.
+
+        Expects a dictionary as input with a list of `Document`s being passed under
+        the "context" key.
+
+        Return type matches the reduce_documents_chain return type.
+
+    Example:
+        .. code-block:: python
+
+            from langchain_community.chat_models import ChatOpenAI
+            from langchain_core.documents import Document
+            from langchain_core.prompts import ChatPromptTemplate
+            from langchain.chains.combine_documents import (
+                create_collapse_documents_chain,
+                create_map_documents_chain,
+                create_map_reduce_documents_chain,
+                create_stuff_documents_chain,
+            )
+
+            llm = ChatOpenAI(model="gpt-3.5-turbo")
+            extract_prompt = ChatPromptTemplate.from_messages(
+                [
+                    ("system", "Given a user question, extract the most relevant parts of the following context:\n\n{context}"),
+                    ("human", "{question}"),
+                ]
+            )
+            map_documents_chain = create_map_documents_chain(llm, extract_prompt)
+            collapse_documents_chain = create_collapse_documents_chain(llm, extract_prompt, token_max=4000)
+
+            answer_prompt = ChatPromptTemplate.from_messages(
+                [
+                    ("system", "Answer the user question using the following context:\n\n{context}"),
+                    ("human", "{question}"),
+                ]
+            )
+            reduce_documents_chain = create_stuff_documents_chain(llm, answer_prompt)
+            map_reduce_documents_chain = create_map_reduce_documents_chain(
+                map_documents_chain,
+                reduce_documents_chain,
+                collapse_documents_chain=collapse_documents_chain,
+            )
+
+            docs = [
+                Document(page_content="Jesse loves red but not yellow"),
+                Document(page_content="Jamal loves green but not as much as he loves orange"),
+            ]
+
+            map_reduce_documents_chain.invoke({"context": docs, "question": "Who loves green?"})
+    """  # noqa: E501
+    assign_mapped_docs = RunnablePassthrough.assign(context=map_documents_chain)
+    if not collapse_documents_chain:
+        return assign_mapped_docs.pipe(
+            reduce_documents_chain, name="map_reduce_documents_chain"
+        )
+    else:
+        return assign_mapped_docs.assign(context=collapse_documents_chain).pipe(
+            reduce_documents_chain, name="map_reduce_documents_chain"
+        )
+
+
+# --- Helper methods for LCEL Runnable chains --- #
+
+
+def _compile_document(inputs: Dict[str, Any]) -> Document:
+    # inputs[DOCUMENTS_KEY] still holds the original Document; "page_content"
+    # holds the mapped string produced by the model.
+    doc = inputs[DOCUMENTS_KEY]
+    return Document(page_content=inputs["page_content"], metadata=doc.metadata)
+
+
+def _format_input_as_list(inputs: Dict[str, Any]) -> List[dict]:
+    # Copy before popping so the caller's input dict is not mutated.
+    inputs = dict(inputs)
+    docs = inputs.pop(DOCUMENTS_KEY)
+    return [{DOCUMENTS_KEY: doc, **inputs} for doc in docs]
+
+
+# --- Legacy Chain --- #
+
+
 class MapReduceDocumentsChain(BaseCombineDocumentsChain):
     """Combining documents by mapping a chain over them, then combining results.
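The map chain's core trick is the fan-out plus `.map()` at the end: one input dict holding a list of documents becomes one dict per document, and a per-document Runnable is lifted over the list. A toy sketch of that pattern (only `langchain_core` is assumed; plain strings stand in for Documents):

from langchain_core.runnables import RunnableLambda

# Per-item runnable: transforms the single "context" entry it receives.
per_item = RunnableLambda(lambda d: d["context"].upper())

def fan_out(inputs: dict) -> list:
    # Mirrors _format_input_as_list above: one dict per list element.
    inputs = dict(inputs)
    docs = inputs.pop("context")
    return [{"context": doc, **inputs} for doc in docs]

chain = RunnableLambda(fan_out) | per_item.map()
assert chain.invoke({"context": ["ab", "cd"], "question": "?"}) == ["AB", "CD"]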
@@ -2,17 +2,186 @@

 from __future__ import annotations

-from typing import Any, Dict, List, Optional, Sequence, Tuple, Type, Union, cast
+import re
+from typing import (
+    Any,
+    Callable,
+    Dict,
+    List,
+    Optional,
+    Sequence,
+    Tuple,
+    Type,
+    Union,
+    cast,
+)

 from langchain_core.documents import Document
-from langchain_core.pydantic_v1 import BaseModel, Extra, create_model, root_validator
+from langchain_core.exceptions import OutputParserException
+from langchain_core.language_models import LanguageModelLike
+from langchain_core.output_parsers import BaseOutputParser, StrOutputParser
+from langchain_core.prompts import BasePromptTemplate
+from langchain_core.pydantic_v1 import (
+    BaseModel,
+    Extra,
+    Field,
+    create_model,
+    root_validator,
+)
+from langchain_core.runnables import (
+    Runnable,
+    RunnableLambda,
+    RunnableParallel,
+)
 from langchain_core.runnables.config import RunnableConfig

 from langchain.callbacks.manager import Callbacks
-from langchain.chains.combine_documents.base import BaseCombineDocumentsChain
+from langchain.chains.combine_documents.base import (
+    DEFAULT_DOCUMENT_PROMPT,
+    DOCUMENTS_KEY,
+    BaseCombineDocumentsChain,
+    format_document_inputs_as_list,
+    validate_prompt,
+)
 from langchain.chains.llm import LLMChain
 from langchain.output_parsers.regex import RegexParser

+# --- LCEL Runnable chains --- #
+
+
+def create_map_rerank_documents_chain(
+    llm: LanguageModelLike,
+    prompt: BasePromptTemplate,
+    *,
+    output_parser: Optional[BaseOutputParser[Dict[str, Any]]] = None,
+    document_prompt: Optional[BasePromptTemplate] = None,
+) -> Runnable[Dict[str, Any], _MapRerankOutput]:
+    """Create a chain that writes candidate answers for each doc and selects the best one.
+
+    Args:
+        llm: Language model to use for responding.
+        prompt: Prompt to use for answering and scoring. Must accept "context" as one
+            of the input variables. Should work well with the given output_parser so
+            that a score and answer can be extracted from each model output. Scores
+            must be comparable (e.g. floats or ints) and a **higher score** should
+            indicate a more relevant answer.
+        output_parser: Output parser to use. Must return a dictionary containing a
+            "score" and "answer" key. If none is provided, will default to a simple
+            regex parser that searches for "Answer:" and "Score:" keywords in the
+            model output.
+        document_prompt: Prompt used for formatting each document into a string. Input
+            variables can be "page_content" or any metadata keys that are in all
+            documents. "page_content" will automatically retrieve the
+            `Document.page_content`, and all other input variables will be
+            automatically retrieved from the `Document.metadata` dictionary. Defaults
+            to a prompt that only contains `Document.page_content`.
+
+    Returns:
+        An LCEL `Runnable` chain.
+
+        Expects a dictionary as input with a list of `Document`s being passed under
+        the "context" key.
+
+        Returns a dictionary as output that looks like:
+
+            {
+                "top_answer": "top-score answer string",
+                "all_answers": [{"score": ..., "answer": "..."}, ...]
+            }
+
+        The scored answers in "all_answers" preserve the order of the input Documents.
+
+    Example:
+        .. code-block:: python
+
+            # pip install -U langchain langchain-community
+
+            from langchain_community.chat_models import ChatOpenAI
+            from langchain_core.documents import Document
+            from langchain_core.output_parsers import SimpleJsonOutputParser
+            from langchain_core.prompts import ChatPromptTemplate
+            from langchain.chains.combine_documents import create_map_rerank_documents_chain
+
+            llm = ChatOpenAI(model="gpt-3.5-turbo-1106").bind(
+                response_format={"type": "json_object"}
+            )
+            prompt = ChatPromptTemplate.from_messages(
+                [
+                    ("system", '''Answer the user question using the given context. Score the
+            answer from 1-5 based on how relevant the context is to the question. Return a JSON string with
+            keys 'answer' and 'score'.\n\n{context}'''),
+                    ("human", "{question}"),
+                ]
+            )
+            chain = create_map_rerank_documents_chain(llm, prompt, output_parser=SimpleJsonOutputParser())
+
+            docs = [
+                Document(page_content="Jesse loves red but not yellow"),
+                Document(page_content="Jamal loves green but not as much as he loves orange"),
+            ]
+
+            chain.invoke({"context": docs, "question": "Who loves green?"})
+            # -> {
+            #     'all_answers': [
+            #         {'answer': 'This context does not provide any information about who loves green.', 'score': 1},
+            #         {'answer': 'Jamal loves green but not as much as he loves orange', 'score': 5}
+            #     ],
+            #     'top_answer': 'Jamal loves green but not as much as he loves orange'
+            # }
+    """  # noqa: E501
+    validate_prompt(prompt, (DOCUMENTS_KEY,))
+    _output_parser = output_parser or (StrOutputParser() | _default_regex)
+
+    # Runnable: Dict with single doc -> {"answer": ..., "score": ...}
+    answer_chain = prompt.pipe(llm, _output_parser, name="answer_and_score")
+
+    # Runnable: Dict with many docs -> [{"answer": ..., "score": ...}, ...]
+    map_chain = (
+        RunnableLambda(cast(Callable, format_document_inputs_as_list))
+        .bind(document_prompt=document_prompt or DEFAULT_DOCUMENT_PROMPT)
+        .pipe(answer_chain.map(), name="answer_and_score_all")
+    )
+
+    # Runnable: Dict with many docs -> {"top_answer": "...", "all_answers": [...]}
+    return (
+        RunnableParallel(all_answers=map_chain)
+        .assign(top_answer=_top_answer)
+        .with_name("map_rerank_documents_chain")
+        .with_types(output_type=_MapRerankOutput)
+    )
+
+
+# --- Helper methods for LCEL Runnable chain --- #
+
+
+def _top_answer(results: Dict[str, Any]) -> str:
+    return max(results["all_answers"], key=lambda x: float(x["score"]))["answer"]
+
+
+def _default_regex(text: str) -> dict:
+    regex = r"(?:Answer:\s)?(.*?)\s*Score:\s(.*)"
+    match = re.search(regex, text, re.IGNORECASE + re.DOTALL)
+    if match is None:
+        raise OutputParserException(
+            ValueError(
+                f"Model output did not match expected regex. Expected {regex}, "
+                f"received: {text}."
+            )
+        )
+    else:
+        return {"answer": match.group(1), "score": match.group(2)}
+
+
+class _MapRerankOutput(BaseModel):
+    top_answer: str = Field(..., description="The highest-scored answer.")
+    all_answers: List[Dict[str, Any]] = Field(
+        ...,
+        description="All answers and scores, in the same order as the input documents.",
+    )
+
+
+# --- Legacy Chain --- #
+
+
 class MapRerankDocumentsChain(BaseCombineDocumentsChain):
     """Combining documents by mapping a chain over them, then reranking results.
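The default parser's regex is worth seeing in action. A small check of the exact pattern from `_default_regex` above (standard library only):

import re

regex = r"(?:Answer:\s)?(.*?)\s*Score:\s(.*)"
text = "Answer: Jamal loves green\nScore: 5"

match = re.search(regex, text, re.IGNORECASE + re.DOTALL)
assert match is not None
# Scores stay strings here; _top_answer compares them via float(...).
assert match.group(1) == "Jamal loves green"
assert match.group(2) == "5"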
@@ -2,13 +2,204 @@

 from __future__ import annotations

-from typing import Any, Callable, List, Optional, Protocol, Tuple
+from typing import (
+    Any,
+    Callable,
+    Dict,
+    List,
+    Optional,
+    Protocol,
+    Sequence,
+    Tuple,
+    Union,
+    cast,
+)

 from langchain_core.documents import Document
+from langchain_core.language_models import LanguageModelLike
+from langchain_core.output_parsers import StrOutputParser
+from langchain_core.prompts import BasePromptTemplate
 from langchain_core.pydantic_v1 import Extra
+from langchain_core.runnables import (
+    Runnable,
+    RunnableLambda,
+    RunnableParallel,
+    RunnablePassthrough,
+)

 from langchain.callbacks.manager import Callbacks
-from langchain.chains.combine_documents.base import BaseCombineDocumentsChain
+from langchain.chains.combine_documents.base import (
+    DEFAULT_DOCUMENT_PROMPT,
+    DEFAULT_DOCUMENT_SEPARATOR,
+    DOCUMENTS_KEY,
+    BaseCombineDocumentsChain,
+    format_document_inputs,
+    format_documents,
+    validate_prompt,
+)
+from langchain.pydantic_v1 import BaseModel

+# --- LCEL Runnable chains --- #
+
+
+def create_collapse_documents_chain(
+    llm: LanguageModelLike,
+    prompt: BasePromptTemplate,
+    *,
+    token_max: int,
+    document_prompt: Optional[BasePromptTemplate] = None,
+    document_separator: str = DEFAULT_DOCUMENT_SEPARATOR,
+    token_len_func: Optional[Callable[[str], int]] = None,
+) -> Runnable[Dict[str, Any], List[Document]]:
+    """Create a chain that collapses a list of Documents until their cumulative token length is below a limit.
+
+    Args:
+        llm: Language model to use for collapsing document contents.
+        prompt: Prompt to use for collapsing document contents. Must accept "context"
+            as one of the input variables. The formatted documents will be passed in
+            as "context".
+        token_max: The maximum cumulative token length that the list of Documents can
+            have.
+        document_prompt: Prompt used for formatting each document into a string. Input
+            variables can be "page_content" or any metadata keys that are in all
+            documents. "page_content" will automatically retrieve the
+            `Document.page_content`, and all other input variables will be
+            automatically retrieved from the `Document.metadata` dictionary. Defaults
+            to a prompt that only contains `Document.page_content`.
+        document_separator: Separator string to use for joining formatted document
+            strings.
+        token_len_func: Optional Callable for computing the token length of a string.
+            Should take a string as input and output an int. If None, will default to
+            `llm.get_num_tokens` if the `llm` has this method, otherwise will use
+            `len`.
+
+    Returns:
+        An LCEL `Runnable` chain.
+
+        Expects a dictionary as input. Input must contain a "context" key with a list
+        of Documents.
+
+        Returns a list of Documents whose cumulative token length when formatted as
+        strings is below `token_max`.
+
+    Example:
+        .. code-block:: python
+
+            # pip install -U langchain langchain-community
+
+            from langchain_community.chat_models import ChatOpenAI
+            from langchain_core.documents import Document
+            from langchain_core.prompts import ChatPromptTemplate
+            from langchain.chains.combine_documents import create_collapse_documents_chain
+
+            llm = ChatOpenAI(model="gpt-3.5-turbo")
+            extract_prompt = ChatPromptTemplate.from_messages(
+                [
+                    ("system", "Given a user question, extract the most relevant parts of the following context:\n\n{context}"),
+                    ("human", "{question}"),
+                ]
+            )
+            collapse_documents_chain = create_collapse_documents_chain(llm, extract_prompt, token_max=1000)
+
+            docs = [
+                Document(page_content="I love yellow. " * 200),
+                Document(page_content="You love green. " * 200),
+            ]
+            collapse_documents_chain.invoke({"context": docs, "question": "What color do you love?"})
+    """  # noqa: E501
+    validate_prompt(prompt, (DOCUMENTS_KEY,))
+    _document_prompt = document_prompt or DEFAULT_DOCUMENT_PROMPT
+    _token_len_func: Callable = token_len_func or getattr(llm, "get_num_tokens", len)  # type: ignore  # noqa: E501
+
+    # Runnable: Dict with many docs -> reduced string.
+    reduce_content = (
+        RunnableLambda(cast(Callable, format_document_inputs))
+        .bind(
+            document_prompt=_document_prompt,
+            document_separator=document_separator,
+        )
+        .pipe(prompt, llm, StrOutputParser(), name="reduce_content")
+    )
+
+    # Runnable: Dict with many docs -> single Document.
+    reduce_docs = RunnableParallel(
+        page_content=reduce_content,
+        metadata=_combine_metadata,
+    ).pipe(_dict_to_document, name="reduce_documents")
+
+    def partition_docs(inputs: dict) -> List[dict]:
+        # Copy before popping so the shared input dict is not mutated.
+        inputs = dict(inputs)
+        docs = inputs.pop(DOCUMENTS_KEY)
+        partitions = split_list_of_docs(
+            docs,
+            _docs_len,
+            token_max,
+            token_len_func=_token_len_func,
+            document_prompt=_document_prompt,
+            document_separator=document_separator,
+        )
+        return [{DOCUMENTS_KEY: docs, **inputs} for docs in partitions]
+
+    partition_and_reduce = RunnableLambda(partition_docs).pipe(
+        reduce_docs.map(), name="partition_and_reduce"
+    )
+
+    def collapse_loop(inputs: dict) -> Union[List[Document], Runnable]:
+        """Recursive collapse loop.
+
+        If the cumulative token length of the documents exceeds the token_max,
+        append another partition and reduce step to the Runnable sequence. Otherwise
+        return the docs.
+        """
+        docs = inputs[DOCUMENTS_KEY]
+        curr_len = _docs_len(
+            docs, _token_len_func, _document_prompt, document_separator
+        )
+        if curr_len > token_max:
+            return RunnablePassthrough.assign(context=partition_and_reduce).pipe(
+                collapse_loop, name="collapse_step"
+            )
+        else:
+            return docs
+
+    return (
+        RunnableLambda(cast(Callable, collapse_loop))
+        .with_name("collapse_documents_chain")
+        .with_types(output_type=_CollapseOutputType)  # type: ignore
+    )
+
+
+# --- Helper methods for LCEL Runnable chains --- #
+
+
+def _combine_metadata(inputs: Dict[str, Any]) -> Dict[Any, str]:
+    docs = inputs[DOCUMENTS_KEY]
+    combined_metadata = {k: str(v) for k, v in docs[0].metadata.items()}
+    for doc in docs[1:]:
+        for k, v in doc.metadata.items():
+            if k in combined_metadata:
+                combined_metadata[k] += f", {v}"
+            else:
+                combined_metadata[k] = str(v)
+    return combined_metadata
+
+
+def _dict_to_document(inputs: Dict[str, Any]) -> Document:
+    return Document(page_content=inputs["page_content"], metadata=inputs["metadata"])
+
+
+def _docs_len(
+    docs: Sequence[Document],
+    token_len_func: Callable[[str], int],
+    document_prompt: BasePromptTemplate,
+    document_separator: str,
+) -> int:
+    return token_len_func(format_documents(docs, document_prompt, document_separator))
+
+
+class _CollapseOutputType(BaseModel):
+    __root__: List[Document]
+
+
+# --- Legacy Chain --- #
+
+
 class CombineDocsProtocol(Protocol):
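The collapse loop is essentially "partition, summarize each partition, repeat until the total fits". A hypothetical standalone sketch of that control flow, with the LLM call replaced by a truncation stub (all names here are illustrative, not from the diff):

from typing import List

def reduce_stub(texts: List[str]) -> str:
    # Stand-in for the prompt | llm "reduce_content" step.
    return " ".join(texts)[:50]

def collapse(texts: List[str], max_len: int, chunk: int = 2) -> List[str]:
    while sum(len(t) for t in texts) > max_len:
        partitions = [texts[i : i + chunk] for i in range(0, len(texts), chunk)]
        texts = [reduce_stub(p) for p in partitions]
    return texts

# Three 100-char "documents" collapse to two 50-char summaries, under the cap.
assert [len(t) for t in collapse(["a" * 100, "b" * 100, "c" * 100], 120)] == [50, 50]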
@@ -2,22 +2,157 @@

 from __future__ import annotations

-from typing import Any, Dict, List, Tuple
+from typing import Any, Callable, Dict, List, Optional, Tuple, cast

 from langchain_core.documents import Document
+from langchain_core.language_models import LanguageModelLike
+from langchain_core.output_parsers import StrOutputParser
 from langchain_core.prompts import BasePromptTemplate, format_document
 from langchain_core.prompts.prompt import PromptTemplate
 from langchain_core.pydantic_v1 import Extra, Field, root_validator
+from langchain_core.runnables import (
+    Runnable,
+    RunnableLambda,
+    RunnablePassthrough,
+)

 from langchain.callbacks.manager import Callbacks
-from langchain.chains.combine_documents.base import BaseCombineDocumentsChain
+from langchain.chains.combine_documents.base import (
+    DEFAULT_DOCUMENT_PROMPT,
+    DOCUMENTS_KEY,
+    INTERMEDIATE_STEPS_KEY,
+    BaseCombineDocumentsChain,
+    validate_prompt,
+)
 from langchain.chains.llm import LLMChain

-
-def _get_default_document_prompt() -> PromptTemplate:
-    return PromptTemplate(input_variables=["page_content"], template="{page_content}")
+# --- LCEL Runnable chain --- #
+
+OUTPUT_KEY = "output"
+
+
+def create_refine_documents_chain(
+    llm: LanguageModelLike,
+    initial_prompt: BasePromptTemplate,
+    refine_prompt: BasePromptTemplate,
+    *,
+    document_prompt: Optional[BasePromptTemplate] = None,
+) -> Runnable:
+    """Create a chain that feeds documents to a model one at a time and updates the output.
+
+    Args:
+        llm: Language model to use for responding.
+        initial_prompt: The prompt to use on the first document. Must accept "context"
+            as one of the input variables. The first document will be passed in as
+            "context".
+        refine_prompt: The prompt to use on all subsequent documents. Must accept
+            "context" and "output" as input variables. A document will be passed in
+            as "context" and the refined output up to this iteration will be passed
+            in as "output".
+        document_prompt: Prompt used for formatting each document into a string. Input
+            variables can be "page_content" or any metadata keys that are in all
+            documents. "page_content" will automatically retrieve the
+            `Document.page_content`, and all other input variables will be
+            automatically retrieved from the `Document.metadata` dictionary. Defaults
+            to a prompt that only contains `Document.page_content`.
+
+    Returns:
+        An LCEL `Runnable` chain.
+
+        Expects a dictionary as input with a list of `Document`s being passed under
+        the "context" key.
+
+        Returns a dictionary as output. The output dictionary contains two keys,
+        "output" and "intermediate_steps". "output" contains the final output.
+        "intermediate_steps" contains the list of intermediate output strings
+        generated by the chain, in the order that they were generated.
+
+    Example:
+        .. code-block:: python
+
+            # pip install -U langchain langchain-community
+
+            from langchain_community.chat_models import ChatOpenAI
+            from langchain_core.prompts import ChatPromptTemplate
+            from langchain.chains.combine_documents.refine import create_refine_documents_chain
+
+            initial_prompt = ChatPromptTemplate.from_messages([
+                ("system", "Summarize this information: {context}"),
+            ])
+            refine_prompt = ChatPromptTemplate.from_messages([
+                ("system", '''You are summarizing a long document one page at a time. \
+            You have summarized part of the document. Given the next page, update your \
+            summary. Respond with only the updated summary and no other text. \
+            Here is your working summary:\n\n{output}.'''),
+                ("human", "Here is the next page:\n\n{context}")
+            ])
+            llm = ChatOpenAI(model="gpt-3.5-turbo")
+            chain = create_refine_documents_chain(llm, initial_prompt, refine_prompt)
+            # chain.invoke({"context": docs})
+    """  # noqa: E501
+    validate_prompt(initial_prompt, (DOCUMENTS_KEY,))
+    validate_prompt(refine_prompt, (DOCUMENTS_KEY, OUTPUT_KEY))
+
+    format_doc: Runnable = RunnableLambda(cast(Callable, _get_and_format_doc)).bind(
+        document_prompt=document_prompt or DEFAULT_DOCUMENT_PROMPT
+    )
+
+    # Runnable: Dict with many docs -> answer string based on first doc
+    initial_response = format_doc.pipe(
+        initial_prompt, llm, StrOutputParser(), name="initial_response"
+    )
+
+    # Runnable: Dict with many docs, current answer, intermediate_steps
+    # -> updated answer based on next doc
+    refine_response = format_doc.pipe(
+        refine_prompt, llm, StrOutputParser(), name="refine_response"
+    )
+
+    # Runnable: Update intermediate_steps from the last output, then compute the
+    # refined output. The steps must be updated first: the index of the next
+    # document to read is len(intermediate_steps).
+    refine_step = RunnablePassthrough.assign(
+        intermediate_steps=_update_intermediate_steps,
+    ).assign(output=refine_response)
+
+    # Function that returns a sequence of refine_steps equal to len(docs) - 1.
+    refine_loop = RunnableLambda(cast(Callable, _runnable_loop)).bind(
+        step=refine_step, step_name="refine_step_{iteration}"
+    )
+
+    # Runnable: Dict with many docs -> {"output": "...", "intermediate_steps": [...]}
+    return (
+        RunnablePassthrough.assign(output=initial_response)
+        .pipe(refine_loop)
+        .pick([OUTPUT_KEY, INTERMEDIATE_STEPS_KEY])
+        .with_name("refine_documents_chain")
+    )
+
+
+# --- Helpers for LCEL Runnable chain --- #
+
+
+def _get_and_format_doc(inputs: dict, document_prompt: BasePromptTemplate) -> dict:
+    # Copy so the shared input dict is not mutated between steps.
+    inputs = dict(inputs)
+    intermediate_steps = inputs.pop(INTERMEDIATE_STEPS_KEY, [])
+    doc = inputs[DOCUMENTS_KEY][len(intermediate_steps)]
+    inputs[DOCUMENTS_KEY] = format_document(doc, document_prompt)
+    return inputs
+
+
+def _runnable_loop(inputs: dict, step: Runnable, step_name: str) -> Runnable:
+    if len(inputs[DOCUMENTS_KEY]) < 2:
+        return RunnablePassthrough()
+    chain: Runnable = step.with_name(step_name.format(iteration=1))
+    for iteration in range(2, len(inputs[DOCUMENTS_KEY])):
+        chain |= step.with_name(step_name.format(iteration=iteration))
+    return chain
+
+
+def _update_intermediate_steps(inputs: dict) -> list:
+    return inputs.get(INTERMEDIATE_STEPS_KEY, []) + [inputs[OUTPUT_KEY]]
+
+
+# --- Legacy Chain --- #
+
+
 class RefineDocumentsChain(BaseCombineDocumentsChain):
@@ -81,7 +216,7 @@ class RefineDocumentsChain(BaseCombineDocumentsChain):
     initial_response_name: str
     """The variable name to format the initial response in when refining."""
     document_prompt: BasePromptTemplate = Field(
-        default_factory=_get_default_document_prompt
+        default_factory=lambda: DEFAULT_DOCUMENT_PROMPT
     )
     """Prompt to use to format each document, gets passed to `format_document`."""
     return_intermediate_steps: bool = False
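Stripped of the Runnable plumbing, the refine chain computes a simple left fold over the documents. A hypothetical sketch with the two LLM calls stubbed out (names here are illustrative):

from typing import List, Tuple

def refine_stub(pages: List[str]) -> Tuple[str, List[str]]:
    output = f"summary({pages[0]})"           # stands in for initial_prompt | llm
    intermediate_steps: List[str] = []
    for page in pages[1:]:
        intermediate_steps.append(output)     # record output before refining it
        output = f"refine({output}, {page})"  # stands in for refine_prompt | llm
    return output, intermediate_steps

out, steps = refine_stub(["p1", "p2", "p3"])
assert out == "refine(refine(summary(p1), p2), p3)"
assert steps == ["summary(p1)", "refine(summary(p1), p2)"]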
@@ -1,12 +1,12 @@
 """Chain that combines documents by stuffing into context."""
-from typing import Any, Dict, List, Optional, Tuple
+from typing import Any, Callable, Dict, List, Optional, Tuple, cast

 from langchain_core.documents import Document
 from langchain_core.language_models import LanguageModelLike
 from langchain_core.output_parsers import BaseOutputParser, StrOutputParser
 from langchain_core.prompts import BasePromptTemplate, format_document
 from langchain_core.pydantic_v1 import Extra, Field, root_validator
-from langchain_core.runnables import Runnable, RunnablePassthrough
+from langchain_core.runnables import Runnable, RunnableLambda

 from langchain.callbacks.manager import Callbacks
 from langchain.chains.combine_documents.base import (
@@ -14,7 +14,8 @@ from langchain.chains.combine_documents.base import (
     DEFAULT_DOCUMENT_SEPARATOR,
     DOCUMENTS_KEY,
     BaseCombineDocumentsChain,
-    _validate_prompt,
+    format_document_inputs,
+    validate_prompt,
 )
 from langchain.chains.llm import LLMChain

@@ -27,13 +28,13 @@ def create_stuff_documents_chain(
     document_prompt: Optional[BasePromptTemplate] = None,
     document_separator: str = DEFAULT_DOCUMENT_SEPARATOR,
 ) -> Runnable[Dict[str, Any], Any]:
-    """Create a chain for passing a list of Documents to a model.
+    """Create a chain that passes a list of documents to a model.

     Args:
-        llm: Language model.
+        llm: Language model to use for responding.
         prompt: Prompt template. Must contain input variable "context", which will be
             used for passing in the formatted documents.
-        output_parser: Output parser. Defaults to StrOutputParser.
+        output_parser: Output parser. Defaults to `StrOutputParser`.
         document_prompt: Prompt used for formatting each document into a string. Input
             variables can be "page_content" or any metadata keys that are in all
             documents. "page_content" will automatically retrieve the
@@ -43,9 +44,12 @@ def create_stuff_documents_chain(
         document_separator: String separator to use between formatted document strings.

     Returns:
-        An LCEL Runnable. The input is a dictionary that must have a "context" key that
-        maps to a List[Document], and any other input variables expected in the prompt.
-        The Runnable return type depends on output_parser used.
+        An LCEL `Runnable` chain.
+
+        Expects a dictionary as input with a list of `Document`s being passed under
+        the "context" key.
+
+        Return type depends on the `output_parser` used.

     Example:
         .. code-block:: python
@@ -58,7 +62,10 @@ def create_stuff_documents_chain(
             from langchain.chains.combine_documents import create_stuff_documents_chain

             prompt = ChatPromptTemplate.from_messages(
-                [("system", "What are everyone's favorite colors:\n\n{context}")]
+                [
+                    ("system", "Use the following context to answer the user's questions:\n\n{context}"),
+                    ("human", "{question}"),
+                ]
             )
             llm = ChatOpenAI(model_name="gpt-3.5-turbo")
             chain = create_stuff_documents_chain(llm, prompt)
@@ -68,26 +75,20 @@ def create_stuff_documents_chain(
                 Document(page_content="Jamal loves green but not as much as he loves orange")
             ]

-            chain.invoke({"context": docs})
+            chain.invoke({"context": docs, "question": "What are everyone's favorite colors?"})
     """  # noqa: E501
-
-    _validate_prompt(prompt)
-    _document_prompt = document_prompt or DEFAULT_DOCUMENT_PROMPT
+    validate_prompt(prompt, (DOCUMENTS_KEY,))
     _output_parser = output_parser or StrOutputParser()

-    def format_docs(inputs: dict) -> str:
-        return document_separator.join(
-            format_document(doc, _document_prompt) for doc in inputs[DOCUMENTS_KEY]
-        )
-
     return (
-        RunnablePassthrough.assign(**{DOCUMENTS_KEY: format_docs}).with_config(
-            run_name="format_inputs"
-        )
-        | prompt
-        | llm
-        | _output_parser
-    ).with_config(run_name="stuff_documents_chain")
+        RunnableLambda(cast(Callable, format_document_inputs))
+        .bind(
+            document_prompt=document_prompt or DEFAULT_DOCUMENT_PROMPT,
+            document_separator=document_separator,
+        )
+        .pipe(prompt, llm, _output_parser, name="stuff_documents_chain")
+    )


 class StuffDocumentsChain(BaseCombineDocumentsChain):
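A `create_stuff_documents_chain` with this signature also shipped in released `langchain`, so the chain can be exercised offline with a canned fake model. A sketch (assumes `langchain` and `langchain-community` are installed; `FakeListLLM` simply replays the given responses):

from langchain_community.llms.fake import FakeListLLM
from langchain_core.documents import Document
from langchain_core.prompts import ChatPromptTemplate

from langchain.chains.combine_documents import create_stuff_documents_chain

llm = FakeListLLM(responses=["Jesse: red. Jamal: orange."])
prompt = ChatPromptTemplate.from_messages(
    [
        ("system", "Use the following context to answer the user's questions:\n\n{context}"),
        ("human", "{question}"),
    ]
)
chain = create_stuff_documents_chain(llm, prompt)

docs = [
    Document(page_content="Jesse loves red but not yellow"),
    Document(page_content="Jamal loves green but not as much as he loves orange"),
]
print(chain.invoke({"context": docs, "question": "What are everyone's favorite colors?"}))
# -> "Jesse: red. Jamal: orange."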