mirror of
https://github.com/hwchase17/langchain.git
synced 2025-05-30 19:49:09 +00:00
langchain[minor]: Add stuff docs runnable (#15178)
Co-authored-by: Harrison Chase <hw.chase.17@gmail.com>
This commit is contained in:
parent
63916cfe35
commit
56fad2e8ff
@ -5,5 +5,11 @@ from langchain.chains.combine_documents.reduce import (
|
||||
collapse_docs,
|
||||
split_list_of_docs,
|
||||
)
|
||||
from langchain.chains.combine_documents.stuff import create_stuff_documents_chain
|
||||
|
||||
__all__ = ["acollapse_docs", "collapse_docs", "split_list_of_docs"]
|
||||
__all__ = [
|
||||
"acollapse_docs",
|
||||
"collapse_docs",
|
||||
"split_list_of_docs",
|
||||
"create_stuff_documents_chain",
|
||||
]
|
||||
|
@ -4,6 +4,7 @@ from abc import ABC, abstractmethod
|
||||
from typing import Any, Dict, List, Optional, Tuple, Type
|
||||
|
||||
from langchain_core.documents import Document
|
||||
from langchain_core.prompts import BasePromptTemplate, PromptTemplate
|
||||
from langchain_core.pydantic_v1 import BaseModel, Field, create_model
|
||||
from langchain_core.runnables.config import RunnableConfig
|
||||
|
||||
@ -14,6 +15,18 @@ from langchain.callbacks.manager import (
|
||||
from langchain.chains.base import Chain
|
||||
from langchain.text_splitter import RecursiveCharacterTextSplitter, TextSplitter
|
||||
|
||||
DEFAULT_DOCUMENT_SEPARATOR = "\n\n"
|
||||
DOCUMENTS_KEY = "context"
|
||||
DEFAULT_DOCUMENT_PROMPT = PromptTemplate.from_template("{page_content}")
|
||||
|
||||
|
||||
def _validate_prompt(prompt: BasePromptTemplate) -> None:
|
||||
if DOCUMENTS_KEY not in prompt.input_variables:
|
||||
raise ValueError(
|
||||
f"Prompt must accept {DOCUMENTS_KEY} as an input variable. Received prompt "
|
||||
f"with input variables: {prompt.input_variables}"
|
||||
)
|
||||
|
||||
|
||||
class BaseCombineDocumentsChain(Chain, ABC):
|
||||
"""Base interface for chains combining documents.
|
||||
|
@ -1,21 +1,93 @@
|
||||
"""Chain that combines documents by stuffing into context."""
|
||||
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
|
||||
from langchain_core.documents import Document
|
||||
from langchain_core.language_models import LanguageModelLike
|
||||
from langchain_core.output_parsers import BaseOutputParser, StrOutputParser
|
||||
from langchain_core.prompts import BasePromptTemplate, format_document
|
||||
from langchain_core.prompts.prompt import PromptTemplate
|
||||
from langchain_core.pydantic_v1 import Extra, Field, root_validator
|
||||
from langchain_core.runnables import Runnable, RunnablePassthrough
|
||||
|
||||
from langchain.callbacks.manager import Callbacks
|
||||
from langchain.chains.combine_documents.base import (
|
||||
DEFAULT_DOCUMENT_PROMPT,
|
||||
DEFAULT_DOCUMENT_SEPARATOR,
|
||||
DOCUMENTS_KEY,
|
||||
BaseCombineDocumentsChain,
|
||||
_validate_prompt,
|
||||
)
|
||||
from langchain.chains.llm import LLMChain
|
||||
|
||||
|
||||
def _get_default_document_prompt() -> PromptTemplate:
|
||||
return PromptTemplate(input_variables=["page_content"], template="{page_content}")
|
||||
def create_stuff_documents_chain(
|
||||
llm: LanguageModelLike,
|
||||
prompt: BasePromptTemplate,
|
||||
*,
|
||||
output_parser: Optional[BaseOutputParser] = None,
|
||||
document_prompt: Optional[BasePromptTemplate] = None,
|
||||
document_separator: str = DEFAULT_DOCUMENT_SEPARATOR,
|
||||
) -> Runnable[Dict[str, Any], Any]:
|
||||
"""Create a chain for passing a list of Documents to a model.
|
||||
|
||||
Args:
|
||||
llm: Language model.
|
||||
prompt: Prompt template. Must contain input variable "context", which will be
|
||||
used for passing in the formatted documents.
|
||||
output_parser: Output parser. Defaults to StrOutputParser.
|
||||
document_prompt: Prompt used for formatting each document into a string. Input
|
||||
variables can be "page_content" or any metadata keys that are in all
|
||||
documents. "page_content" will automatically retrieve the
|
||||
`Document.page_content`, and all other inputs variables will be
|
||||
automatically retrieved from the `Document.metadata` dictionary. Default to
|
||||
a prompt that only contains `Document.page_content`.
|
||||
document_separator: String separator to use between formatted document strings.
|
||||
|
||||
Returns:
|
||||
An LCEL Runnable. The input is a dictionary that must have a "context" key that
|
||||
maps to a List[Document], and any other input variables expected in the prompt.
|
||||
The Runnable return type depends on output_parser used.
|
||||
|
||||
Example:
|
||||
.. code-block:: python
|
||||
|
||||
# pip install -U langchain langchain-community
|
||||
|
||||
from langchain_community.chat_models import ChatOpenAI
|
||||
from langchain_core.documents import Document
|
||||
from langchain_core.prompts import ChatPromptTemplate
|
||||
from langchain.chains.combine_documents import create_stuff_documents_chain
|
||||
|
||||
prompt = ChatPromptTemplate.from_messages(
|
||||
[("system", "What are everyone's favorite colors:\n\n{context}")]
|
||||
)
|
||||
llm = ChatOpenAI(model_name="gpt-3.5-turbo")
|
||||
chain = create_stuff_documents_chain(llm, prompt)
|
||||
|
||||
docs = [
|
||||
Document(page_content="Jesse loves red but not yellow"),
|
||||
Document(page_content = "Jamal loves green but not as much as he loves orange")
|
||||
]
|
||||
|
||||
chain.invoke({"context": docs})
|
||||
""" # noqa: E501
|
||||
|
||||
_validate_prompt(prompt)
|
||||
_document_prompt = document_prompt or DEFAULT_DOCUMENT_PROMPT
|
||||
_output_parser = output_parser or StrOutputParser()
|
||||
|
||||
def format_docs(inputs: dict) -> str:
|
||||
return document_separator.join(
|
||||
format_document(doc, _document_prompt) for doc in inputs[DOCUMENTS_KEY]
|
||||
)
|
||||
|
||||
return (
|
||||
RunnablePassthrough.assign(**{DOCUMENTS_KEY: format_docs}).with_config(
|
||||
run_name="format_inputs"
|
||||
)
|
||||
| prompt
|
||||
| llm
|
||||
| _output_parser
|
||||
).with_config(run_name="stuff_documents_chain")
|
||||
|
||||
|
||||
class StuffDocumentsChain(BaseCombineDocumentsChain):
|
||||
@ -60,7 +132,7 @@ class StuffDocumentsChain(BaseCombineDocumentsChain):
|
||||
"""LLM chain which is called with the formatted document string,
|
||||
along with any other inputs."""
|
||||
document_prompt: BasePromptTemplate = Field(
|
||||
default_factory=_get_default_document_prompt
|
||||
default_factory=lambda: DEFAULT_DOCUMENT_PROMPT
|
||||
)
|
||||
"""Prompt to use to format each document, gets passed to `format_document`."""
|
||||
document_variable_name: str
|
||||
|
Loading…
Reference in New Issue
Block a user