mirror of
https://github.com/hwchase17/langchain.git
synced 2025-06-21 14:18:52 +00:00
langchain[minor]: Add stuff docs runnable (#15178)
Co-authored-by: Harrison Chase <hw.chase.17@gmail.com>
This commit is contained in:
parent
63916cfe35
commit
56fad2e8ff
@ -5,5 +5,11 @@ from langchain.chains.combine_documents.reduce import (
|
|||||||
collapse_docs,
|
collapse_docs,
|
||||||
split_list_of_docs,
|
split_list_of_docs,
|
||||||
)
|
)
|
||||||
|
from langchain.chains.combine_documents.stuff import create_stuff_documents_chain
|
||||||
|
|
||||||
__all__ = ["acollapse_docs", "collapse_docs", "split_list_of_docs"]
|
__all__ = [
|
||||||
|
"acollapse_docs",
|
||||||
|
"collapse_docs",
|
||||||
|
"split_list_of_docs",
|
||||||
|
"create_stuff_documents_chain",
|
||||||
|
]
|
||||||
|
@ -4,6 +4,7 @@ from abc import ABC, abstractmethod
|
|||||||
from typing import Any, Dict, List, Optional, Tuple, Type
|
from typing import Any, Dict, List, Optional, Tuple, Type
|
||||||
|
|
||||||
from langchain_core.documents import Document
|
from langchain_core.documents import Document
|
||||||
|
from langchain_core.prompts import BasePromptTemplate, PromptTemplate
|
||||||
from langchain_core.pydantic_v1 import BaseModel, Field, create_model
|
from langchain_core.pydantic_v1 import BaseModel, Field, create_model
|
||||||
from langchain_core.runnables.config import RunnableConfig
|
from langchain_core.runnables.config import RunnableConfig
|
||||||
|
|
||||||
@ -14,6 +15,18 @@ from langchain.callbacks.manager import (
|
|||||||
from langchain.chains.base import Chain
|
from langchain.chains.base import Chain
|
||||||
from langchain.text_splitter import RecursiveCharacterTextSplitter, TextSplitter
|
from langchain.text_splitter import RecursiveCharacterTextSplitter, TextSplitter
|
||||||
|
|
||||||
|
DEFAULT_DOCUMENT_SEPARATOR = "\n\n"
|
||||||
|
DOCUMENTS_KEY = "context"
|
||||||
|
DEFAULT_DOCUMENT_PROMPT = PromptTemplate.from_template("{page_content}")
|
||||||
|
|
||||||
|
|
||||||
|
def _validate_prompt(prompt: BasePromptTemplate) -> None:
|
||||||
|
if DOCUMENTS_KEY not in prompt.input_variables:
|
||||||
|
raise ValueError(
|
||||||
|
f"Prompt must accept {DOCUMENTS_KEY} as an input variable. Received prompt "
|
||||||
|
f"with input variables: {prompt.input_variables}"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class BaseCombineDocumentsChain(Chain, ABC):
|
class BaseCombineDocumentsChain(Chain, ABC):
|
||||||
"""Base interface for chains combining documents.
|
"""Base interface for chains combining documents.
|
||||||
|
@ -1,21 +1,93 @@
|
|||||||
"""Chain that combines documents by stuffing into context."""
|
"""Chain that combines documents by stuffing into context."""
|
||||||
|
|
||||||
from typing import Any, Dict, List, Optional, Tuple
|
from typing import Any, Dict, List, Optional, Tuple
|
||||||
|
|
||||||
from langchain_core.documents import Document
|
from langchain_core.documents import Document
|
||||||
|
from langchain_core.language_models import LanguageModelLike
|
||||||
|
from langchain_core.output_parsers import BaseOutputParser, StrOutputParser
|
||||||
from langchain_core.prompts import BasePromptTemplate, format_document
|
from langchain_core.prompts import BasePromptTemplate, format_document
|
||||||
from langchain_core.prompts.prompt import PromptTemplate
|
|
||||||
from langchain_core.pydantic_v1 import Extra, Field, root_validator
|
from langchain_core.pydantic_v1 import Extra, Field, root_validator
|
||||||
|
from langchain_core.runnables import Runnable, RunnablePassthrough
|
||||||
|
|
||||||
from langchain.callbacks.manager import Callbacks
|
from langchain.callbacks.manager import Callbacks
|
||||||
from langchain.chains.combine_documents.base import (
|
from langchain.chains.combine_documents.base import (
|
||||||
|
DEFAULT_DOCUMENT_PROMPT,
|
||||||
|
DEFAULT_DOCUMENT_SEPARATOR,
|
||||||
|
DOCUMENTS_KEY,
|
||||||
BaseCombineDocumentsChain,
|
BaseCombineDocumentsChain,
|
||||||
|
_validate_prompt,
|
||||||
)
|
)
|
||||||
from langchain.chains.llm import LLMChain
|
from langchain.chains.llm import LLMChain
|
||||||
|
|
||||||
|
|
||||||
def _get_default_document_prompt() -> PromptTemplate:
|
def create_stuff_documents_chain(
|
||||||
return PromptTemplate(input_variables=["page_content"], template="{page_content}")
|
llm: LanguageModelLike,
|
||||||
|
prompt: BasePromptTemplate,
|
||||||
|
*,
|
||||||
|
output_parser: Optional[BaseOutputParser] = None,
|
||||||
|
document_prompt: Optional[BasePromptTemplate] = None,
|
||||||
|
document_separator: str = DEFAULT_DOCUMENT_SEPARATOR,
|
||||||
|
) -> Runnable[Dict[str, Any], Any]:
|
||||||
|
"""Create a chain for passing a list of Documents to a model.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
llm: Language model.
|
||||||
|
prompt: Prompt template. Must contain input variable "context", which will be
|
||||||
|
used for passing in the formatted documents.
|
||||||
|
output_parser: Output parser. Defaults to StrOutputParser.
|
||||||
|
document_prompt: Prompt used for formatting each document into a string. Input
|
||||||
|
variables can be "page_content" or any metadata keys that are in all
|
||||||
|
documents. "page_content" will automatically retrieve the
|
||||||
|
`Document.page_content`, and all other inputs variables will be
|
||||||
|
automatically retrieved from the `Document.metadata` dictionary. Default to
|
||||||
|
a prompt that only contains `Document.page_content`.
|
||||||
|
document_separator: String separator to use between formatted document strings.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
An LCEL Runnable. The input is a dictionary that must have a "context" key that
|
||||||
|
maps to a List[Document], and any other input variables expected in the prompt.
|
||||||
|
The Runnable return type depends on output_parser used.
|
||||||
|
|
||||||
|
Example:
|
||||||
|
.. code-block:: python
|
||||||
|
|
||||||
|
# pip install -U langchain langchain-community
|
||||||
|
|
||||||
|
from langchain_community.chat_models import ChatOpenAI
|
||||||
|
from langchain_core.documents import Document
|
||||||
|
from langchain_core.prompts import ChatPromptTemplate
|
||||||
|
from langchain.chains.combine_documents import create_stuff_documents_chain
|
||||||
|
|
||||||
|
prompt = ChatPromptTemplate.from_messages(
|
||||||
|
[("system", "What are everyone's favorite colors:\n\n{context}")]
|
||||||
|
)
|
||||||
|
llm = ChatOpenAI(model_name="gpt-3.5-turbo")
|
||||||
|
chain = create_stuff_documents_chain(llm, prompt)
|
||||||
|
|
||||||
|
docs = [
|
||||||
|
Document(page_content="Jesse loves red but not yellow"),
|
||||||
|
Document(page_content = "Jamal loves green but not as much as he loves orange")
|
||||||
|
]
|
||||||
|
|
||||||
|
chain.invoke({"context": docs})
|
||||||
|
""" # noqa: E501
|
||||||
|
|
||||||
|
_validate_prompt(prompt)
|
||||||
|
_document_prompt = document_prompt or DEFAULT_DOCUMENT_PROMPT
|
||||||
|
_output_parser = output_parser or StrOutputParser()
|
||||||
|
|
||||||
|
def format_docs(inputs: dict) -> str:
|
||||||
|
return document_separator.join(
|
||||||
|
format_document(doc, _document_prompt) for doc in inputs[DOCUMENTS_KEY]
|
||||||
|
)
|
||||||
|
|
||||||
|
return (
|
||||||
|
RunnablePassthrough.assign(**{DOCUMENTS_KEY: format_docs}).with_config(
|
||||||
|
run_name="format_inputs"
|
||||||
|
)
|
||||||
|
| prompt
|
||||||
|
| llm
|
||||||
|
| _output_parser
|
||||||
|
).with_config(run_name="stuff_documents_chain")
|
||||||
|
|
||||||
|
|
||||||
class StuffDocumentsChain(BaseCombineDocumentsChain):
|
class StuffDocumentsChain(BaseCombineDocumentsChain):
|
||||||
@ -60,7 +132,7 @@ class StuffDocumentsChain(BaseCombineDocumentsChain):
|
|||||||
"""LLM chain which is called with the formatted document string,
|
"""LLM chain which is called with the formatted document string,
|
||||||
along with any other inputs."""
|
along with any other inputs."""
|
||||||
document_prompt: BasePromptTemplate = Field(
|
document_prompt: BasePromptTemplate = Field(
|
||||||
default_factory=_get_default_document_prompt
|
default_factory=lambda: DEFAULT_DOCUMENT_PROMPT
|
||||||
)
|
)
|
||||||
"""Prompt to use to format each document, gets passed to `format_document`."""
|
"""Prompt to use to format each document, gets passed to `format_document`."""
|
||||||
document_variable_name: str
|
document_variable_name: str
|
||||||
|
Loading…
Reference in New Issue
Block a user