langchain[minor]: Add stuff docs runnable (#15178)

Co-authored-by: Harrison Chase <hw.chase.17@gmail.com>
This commit is contained in:
Bagatur 2023-12-26 15:20:00 -05:00 committed by GitHub
parent 63916cfe35
commit 56fad2e8ff
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 97 additions and 6 deletions

View File

@ -5,5 +5,11 @@ from langchain.chains.combine_documents.reduce import (
collapse_docs, collapse_docs,
split_list_of_docs, split_list_of_docs,
) )
from langchain.chains.combine_documents.stuff import create_stuff_documents_chain
__all__ = ["acollapse_docs", "collapse_docs", "split_list_of_docs"] __all__ = [
"acollapse_docs",
"collapse_docs",
"split_list_of_docs",
"create_stuff_documents_chain",
]

View File

@ -4,6 +4,7 @@ from abc import ABC, abstractmethod
from typing import Any, Dict, List, Optional, Tuple, Type from typing import Any, Dict, List, Optional, Tuple, Type
from langchain_core.documents import Document from langchain_core.documents import Document
from langchain_core.prompts import BasePromptTemplate, PromptTemplate
from langchain_core.pydantic_v1 import BaseModel, Field, create_model from langchain_core.pydantic_v1 import BaseModel, Field, create_model
from langchain_core.runnables.config import RunnableConfig from langchain_core.runnables.config import RunnableConfig
@ -14,6 +15,18 @@ from langchain.callbacks.manager import (
from langchain.chains.base import Chain from langchain.chains.base import Chain
from langchain.text_splitter import RecursiveCharacterTextSplitter, TextSplitter from langchain.text_splitter import RecursiveCharacterTextSplitter, TextSplitter
# Default string placed between formatted documents when they are joined.
DEFAULT_DOCUMENT_SEPARATOR = "\n\n"
# Input key that holds the documents; prompts are validated to accept this variable.
DOCUMENTS_KEY = "context"
# Default per-document prompt: renders only the `page_content` variable.
DEFAULT_DOCUMENT_PROMPT = PromptTemplate.from_template("{page_content}")
def _validate_prompt(prompt: BasePromptTemplate) -> None:
    """Check that ``prompt`` declares the documents key as an input variable.

    Args:
        prompt: Prompt template whose ``input_variables`` are inspected.

    Raises:
        ValueError: If ``DOCUMENTS_KEY`` is not among ``prompt.input_variables``.
    """
    declared_variables = prompt.input_variables
    if DOCUMENTS_KEY in declared_variables:
        return
    raise ValueError(
        f"Prompt must accept {DOCUMENTS_KEY} as an input variable. Received prompt "
        f"with input variables: {declared_variables}"
    )
class BaseCombineDocumentsChain(Chain, ABC): class BaseCombineDocumentsChain(Chain, ABC):
"""Base interface for chains combining documents. """Base interface for chains combining documents.

View File

@ -1,21 +1,93 @@
"""Chain that combines documents by stuffing into context.""" """Chain that combines documents by stuffing into context."""
from typing import Any, Dict, List, Optional, Tuple from typing import Any, Dict, List, Optional, Tuple
from langchain_core.documents import Document from langchain_core.documents import Document
from langchain_core.language_models import LanguageModelLike
from langchain_core.output_parsers import BaseOutputParser, StrOutputParser
from langchain_core.prompts import BasePromptTemplate, format_document from langchain_core.prompts import BasePromptTemplate, format_document
from langchain_core.prompts.prompt import PromptTemplate
from langchain_core.pydantic_v1 import Extra, Field, root_validator from langchain_core.pydantic_v1 import Extra, Field, root_validator
from langchain_core.runnables import Runnable, RunnablePassthrough
from langchain.callbacks.manager import Callbacks from langchain.callbacks.manager import Callbacks
from langchain.chains.combine_documents.base import ( from langchain.chains.combine_documents.base import (
DEFAULT_DOCUMENT_PROMPT,
DEFAULT_DOCUMENT_SEPARATOR,
DOCUMENTS_KEY,
BaseCombineDocumentsChain, BaseCombineDocumentsChain,
_validate_prompt,
) )
from langchain.chains.llm import LLMChain from langchain.chains.llm import LLMChain
def create_stuff_documents_chain(
    llm: LanguageModelLike,
    prompt: BasePromptTemplate,
    *,
    output_parser: Optional[BaseOutputParser] = None,
    document_prompt: Optional[BasePromptTemplate] = None,
    document_separator: str = DEFAULT_DOCUMENT_SEPARATOR,
) -> Runnable[Dict[str, Any], Any]:
    """Build a runnable that stuffs a list of Documents into a single model call.

    Each document is formatted with ``document_prompt``, the formatted strings
    are joined with ``document_separator``, and the result replaces the
    "context" entry of the input before it is piped through ``prompt``, ``llm``
    and the output parser.

    Args:
        llm: Language model.
        prompt: Prompt template. Must contain the input variable "context", which
            receives the formatted documents.
        output_parser: Output parser. Defaults to StrOutputParser.
        document_prompt: Prompt used for formatting each document into a string.
            Input variables can be "page_content" or any metadata keys that are in
            all documents. "page_content" will automatically retrieve the
            `Document.page_content`, and all other input variables will be
            automatically retrieved from the `Document.metadata` dictionary.
            Defaults to a prompt that only contains `Document.page_content`.
        document_separator: String separator to use between formatted document
            strings.

    Returns:
        An LCEL Runnable. The input is a dictionary that must have a "context" key
        that maps to a List[Document], and any other input variables expected in
        the prompt. The Runnable return type depends on the output_parser used.

    Raises:
        ValueError: If ``prompt`` does not accept "context" as an input variable.

    Example:
        .. code-block:: python

            # pip install -U langchain langchain-community

            from langchain_community.chat_models import ChatOpenAI
            from langchain_core.documents import Document
            from langchain_core.prompts import ChatPromptTemplate
            from langchain.chains.combine_documents import create_stuff_documents_chain

            prompt = ChatPromptTemplate.from_messages(
                [("system", "What are everyone's favorite colors:\n\n{context}")]
            )
            llm = ChatOpenAI(model_name="gpt-3.5-turbo")
            chain = create_stuff_documents_chain(llm, prompt)

            docs = [
                Document(page_content="Jesse loves red but not yellow"),
                Document(page_content="Jamal loves green but not as much as he loves orange")
            ]

            chain.invoke({"context": docs})
    """  # noqa: E501
    # Fail fast, before any runnable is assembled.
    _validate_prompt(prompt)

    per_document_prompt = document_prompt or DEFAULT_DOCUMENT_PROMPT
    parser = output_parser or StrOutputParser()

    def format_docs(inputs: dict) -> str:
        # Render every document, then join them into one context string.
        rendered = [
            format_document(doc, per_document_prompt)
            for doc in inputs[DOCUMENTS_KEY]
        ]
        return document_separator.join(rendered)

    # Overwrite the documents entry with its formatted string; other keys pass through.
    format_inputs = RunnablePassthrough.assign(
        **{DOCUMENTS_KEY: format_docs}
    ).with_config(run_name="format_inputs")

    pipeline = format_inputs | prompt | llm | parser
    return pipeline.with_config(run_name="stuff_documents_chain")
class StuffDocumentsChain(BaseCombineDocumentsChain): class StuffDocumentsChain(BaseCombineDocumentsChain):
@ -60,7 +132,7 @@ class StuffDocumentsChain(BaseCombineDocumentsChain):
"""LLM chain which is called with the formatted document string, """LLM chain which is called with the formatted document string,
along with any other inputs.""" along with any other inputs."""
document_prompt: BasePromptTemplate = Field( document_prompt: BasePromptTemplate = Field(
default_factory=_get_default_document_prompt default_factory=lambda: DEFAULT_DOCUMENT_PROMPT
) )
"""Prompt to use to format each document, gets passed to `format_document`.""" """Prompt to use to format each document, gets passed to `format_document`."""
document_variable_name: str document_variable_name: str