feat(langchain): add stuff and map reduce chains (#32333)

* Add stuff and map reduce chains
* We'll need to rename and add unit tests to the chains prior to
official release
Eugene Yurtsev 2025-08-07 15:20:05 -04:00 committed by GitHub
parent ac706c77d4
commit 754528d23f
13 changed files with 1416 additions and 12 deletions


@@ -0,0 +1,39 @@
"""Internal document utilities."""
from __future__ import annotations
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from langchain_core.documents import Document
def format_document_xml(doc: Document) -> str:
"""Format a document as XML-like structure for LLM consumption.
Args:
doc: Document to format
Returns:
Document wrapped in XML tags:
<document>
<id>...</id>
<content>...</content>
<metadata>...</metadata>
</document>
Note:
Does not generate valid XML or escape special characters.
Intended for semi-structured LLM input only.
"""
id_str = f"<id>{doc.id}</id>" if doc.id is not None else "<id></id>"
metadata_str = ""
if doc.metadata:
metadata_items = [f"{k}: {v!s}" for k, v in doc.metadata.items()]
metadata_str = f"<metadata>{', '.join(metadata_items)}</metadata>"
return (
f"<document>{id_str}"
f"<content>{doc.page_content}</content>"
f"{metadata_str}"
f"</document>"
)
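For orientation, a minimal sketch of what this helper emits (output shown as a comment; the whole string is a single line, matching the f-string above):

```python
from langchain_core.documents import Document

doc = Document(
    id="doc-1",
    page_content="LangGraph ships a Send API.",
    metadata={"source": "notes.md"},
)
print(format_document_xml(doc))
# <document><id>doc-1</id><content>LangGraph ships a Send API.</content><metadata>source: notes.md</metadata></document>
```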


@@ -0,0 +1,36 @@
"""Lazy import utilities."""
from importlib import import_module
from typing import Union
def import_attr(
attr_name: str,
module_name: Union[str, None],
package: Union[str, None],
) -> object:
"""Import an attribute from a module located in a package.
This utility function is used in custom __getattr__ methods within __init__.py
files to dynamically import attributes.
Args:
attr_name: The name of the attribute to import.
module_name: The name of the module to import from. If None, the attribute
is imported from the package itself.
package: The name of the package where the module is located.
"""
if module_name == "__module__" or module_name is None:
try:
result = import_module(f".{attr_name}", package=package)
except ModuleNotFoundError:
msg = f"module '{package!r}' has no attribute {attr_name!r}"
raise AttributeError(msg) from None
else:
try:
module = import_module(f".{module_name}", package=package)
except ModuleNotFoundError as err:
msg = f"module '{package!r}.{module_name!r}' not found ({err})"
raise ImportError(msg) from None
result = getattr(module, attr_name)
return result
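A sketch of the intended call site: a module-level `__getattr__` inside an `__init__.py`. The mapping and import path below are illustrative, not part of this commit:

```python
# hypothetical __init__.py that lazily resolves attributes via import_attr
from langchain._internal._lazy_import import import_attr  # assumed location of the helper above

# attribute name -> submodule that defines it
_DYNAMIC_IMPORTS = {
    "create_map_reduce_chain": "map_reduce",
    "create_stuff_documents_chain": "stuff",
}


def __getattr__(attr_name: str) -> object:
    module_name = _DYNAMIC_IMPORTS.get(attr_name)
    result = import_attr(attr_name, module_name, __spec__.parent)
    globals()[attr_name] = result  # cache so __getattr__ runs at most once per attribute
    return result
```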


@@ -0,0 +1,166 @@
"""Internal prompt resolution utilities.
This module provides utilities for resolving different types of prompt specifications
into standardized message formats for language models. It supports both synchronous
and asynchronous prompt resolution with automatic detection of callable types.
The module is designed to handle common prompt patterns across LangChain components,
particularly for summarization chains and other document processing workflows.
"""
from __future__ import annotations
import inspect
from typing import TYPE_CHECKING, Callable, Union
if TYPE_CHECKING:
from collections.abc import Awaitable
from langchain_core.messages import MessageLikeRepresentation
from langgraph.runtime import Runtime
from langchain._internal._typing import ContextT, StateT
def resolve_prompt(
prompt: Union[
str,
None,
Callable[[StateT, Runtime[ContextT]], list[MessageLikeRepresentation]],
],
state: StateT,
runtime: Runtime[ContextT],
default_user_content: str,
default_system_content: str,
) -> list[MessageLikeRepresentation]:
"""Resolve a prompt specification into a list of messages.
Handles prompt resolution across different strategies. Supports callable functions,
string system messages, and None for default behavior.
Args:
prompt: The prompt specification to resolve. Can be:
- Callable: Function taking (state, runtime) returning message list.
- str: A system message string.
- None: Use the provided default system message.
state: Current state, passed to callable prompts.
runtime: LangGraph runtime instance, passed to callable prompts.
default_user_content: User content to include (e.g., document text).
default_system_content: Default system message when prompt is None.
Returns:
List of message dictionaries for language models, typically containing
a system message and user message with content.
Raises:
TypeError: If prompt type is not str, None, or callable.
Example:
```python
def custom_prompt(state, runtime):
return [{"role": "system", "content": "Custom"}]
messages = resolve_prompt(custom_prompt, state, runtime, "content", "default")
messages = resolve_prompt("Custom system", state, runtime, "content", "default")
messages = resolve_prompt(None, state, runtime, "content", "Default")
```
Note:
Callable prompts have full control over message structure; the
default_user_content parameter is ignored for them. String/None prompts
create the standard system + user structure.
"""
if callable(prompt):
return prompt(state, runtime)
if isinstance(prompt, str):
system_msg = prompt
elif prompt is None:
system_msg = default_system_content
else:
msg = f"Invalid prompt type: {type(prompt)}. Expected str, None, or callable."
raise TypeError(msg)
return [
{"role": "system", "content": system_msg},
{"role": "user", "content": default_user_content},
]
async def aresolve_prompt(
prompt: Union[
str,
None,
Callable[[StateT, Runtime[ContextT]], list[MessageLikeRepresentation]],
Callable[
[StateT, Runtime[ContextT]], Awaitable[list[MessageLikeRepresentation]]
],
],
state: StateT,
runtime: Runtime[ContextT],
default_user_content: str,
default_system_content: str,
) -> list[MessageLikeRepresentation]:
"""Async version of resolve_prompt supporting both sync and async callables.
Handles prompt resolution across different strategies. Supports sync/async callable
functions, string system messages, and None for default behavior.
Args:
prompt: The prompt specification to resolve. Can be:
- Callable (sync): Function taking (state, runtime) returning message list.
- Callable (async): Async function taking (state, runtime) returning
awaitable message list.
- str: A system message string.
- None: Use the provided default system message.
state: Current state, passed to callable prompts.
runtime: LangGraph runtime instance, passed to callable prompts.
default_user_content: User content to include (e.g., document text).
default_system_content: Default system message when prompt is None.
Returns:
List of message dictionaries for language models, typically containing
a system message and user message with content.
Raises:
TypeError: If prompt type is not str, None, or callable.
Example:
```python
async def async_prompt(state, runtime):
return [{"role": "system", "content": "Async"}]
def sync_prompt(state, runtime):
return [{"role": "system", "content": "Sync"}]
messages = await aresolve_prompt(
async_prompt, state, runtime, "content", "default"
)
messages = await aresolve_prompt(
sync_prompt, state, runtime, "content", "default"
)
messages = await aresolve_prompt("Custom", state, runtime, "content", "default")
```
Note:
Callable prompts have full control over message structure; the
default_user_content parameter is ignored for them. Async callables are
detected and awaited automatically.
"""
if callable(prompt):
result = prompt(state, runtime)
# Check if the result is awaitable (async function)
if inspect.isawaitable(result):
return await result
return result
if isinstance(prompt, str):
system_msg = prompt
elif prompt is None:
system_msg = default_system_content
else:
msg = f"Invalid prompt type: {type(prompt)}. Expected str, None, or callable."
raise TypeError(msg)
return [
{"role": "system", "content": system_msg},
{"role": "user", "content": default_user_content},
]


@@ -0,0 +1,65 @@
"""Private typing utilities for langchain."""
from __future__ import annotations
from typing import TYPE_CHECKING, Any, ClassVar, Protocol, TypeVar, Union
from langgraph.graph._node import StateNode
from pydantic import BaseModel
from typing_extensions import TypeAlias
if TYPE_CHECKING:
from dataclasses import Field
class TypedDictLikeV1(Protocol):
"""Protocol to represent types that behave like TypedDicts.
Version 1: using `ClassVar` for keys.
"""
__required_keys__: ClassVar[frozenset[str]]
__optional_keys__: ClassVar[frozenset[str]]
class TypedDictLikeV2(Protocol):
"""Protocol to represent types that behave like TypedDicts.
Version 2: not using `ClassVar` for keys.
"""
__required_keys__: frozenset[str]
__optional_keys__: frozenset[str]
class DataclassLike(Protocol):
"""Protocol to represent types that behave like dataclasses.
Inspired by the private _DataclassT from dataclasses that uses a similar
protocol as a bound.
"""
__dataclass_fields__: ClassVar[dict[str, Field[Any]]]
StateLike: TypeAlias = Union[TypedDictLikeV1, TypedDictLikeV2, DataclassLike, BaseModel]
"""Type alias for state-like types.
It can either be a `TypedDict`, `dataclass`, or Pydantic `BaseModel`.
Note: we cannot use either `TypedDict` or `dataclass` directly due to limitations in
type checking.
"""
StateT = TypeVar("StateT", bound=StateLike)
"""Type variable used to represent the state in a graph."""
ContextT = TypeVar("ContextT", bound=Union[StateLike, None])
"""Type variable for context types."""
__all__ = [
"ContextT",
"StateLike",
"StateNode",
"StateT",
]
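For orientation, all three shapes below satisfy `StateLike` and can therefore bind `StateT`; a minimal sketch, not part of this commit:

```python
from dataclasses import dataclass

from pydantic import BaseModel
from typing_extensions import TypedDict


class DictState(TypedDict):  # TypedDict-like: exposes __required_keys__ / __optional_keys__
    documents: list[str]


@dataclass
class DataclassState:  # dataclass-like: exposes __dataclass_fields__
    documents: list[str]


class ModelState(BaseModel):  # Pydantic BaseModel
    documents: list[str]
```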


@@ -0,0 +1,7 @@
# Re-exporting internal utilities from LangGraph for internal use in LangChain.
# A different wrapper needs to be created for this purpose in LangChain.
from langgraph._internal._runnable import RunnableCallable
__all__ = [
"RunnableCallable",
]


@@ -0,0 +1,9 @@
from langchain.chains.documents import (
create_map_reduce_chain,
create_stuff_documents_chain,
)
__all__ = [
"create_map_reduce_chain",
"create_stuff_documents_chain",
]


@@ -0,0 +1,17 @@
"""Document extraction chains.
This module provides different strategies for extracting information from collections
of documents using LangGraph and modern language models.
Available Strategies:
- Stuff: Processes all documents together in a single context window
- Map-Reduce: Processes documents in parallel (map), then combines results (reduce)
"""
from langchain.chains.documents.map_reduce import create_map_reduce_chain
from langchain.chains.documents.stuff import create_stuff_documents_chain
__all__ = [
"create_map_reduce_chain",
"create_stuff_documents_chain",
]


@@ -0,0 +1,586 @@
"""Map-Reduce Extraction Implementation using LangGraph Send API."""
from __future__ import annotations
import operator
from typing import (
TYPE_CHECKING,
Annotated,
Any,
Callable,
Generic,
Literal,
Optional,
Union,
cast,
)
from langgraph.graph import END, START, StateGraph
from langgraph.types import Send
from typing_extensions import NotRequired, TypedDict
from langchain._internal._documents import format_document_xml
from langchain._internal._prompts import aresolve_prompt, resolve_prompt
from langchain._internal._typing import ContextT, StateNode
from langchain._internal._utils import RunnableCallable
from langchain.chat_models import init_chat_model
if TYPE_CHECKING:
from langchain_core.documents import Document
from langchain_core.language_models.chat_models import BaseChatModel
# PyCharm is unable to identify that AIMessage is used in the cast below
from langchain_core.messages import (
AIMessage,
MessageLikeRepresentation,
)
from langchain_core.runnables import RunnableConfig
from langgraph.runtime import Runtime
from pydantic import BaseModel
class ExtractionResult(TypedDict):
"""Result from processing a document or group of documents."""
indexes: list[int]
"""Document indexes that contributed to this result."""
result: Any
"""Extracted result from the document(s)."""
class MapReduceState(TypedDict):
"""State for map-reduce extraction chain.
This state tracks the map-reduce process where documents are processed
in parallel during the map phase, then combined in the reduce phase.
"""
documents: list[Document]
"""List of documents to process."""
map_results: Annotated[list[ExtractionResult], operator.add]
"""Individual results from the map phase."""
result: NotRequired[Any]
"""Final combined result from the reduce phase if applicable."""
# The payload for the map phase is a list of documents and their indexes.
# The current implementation only supports a single document per map operation,
# but the structure allows for future expansion to process a group of documents.
# A user would provide an input split function that returns groups of documents
# to process together, if desired.
class MapState(TypedDict):
"""State for individual map operations."""
documents: list[Document]
"""List of documents to process in map phase."""
indexes: list[int]
"""List of indexes of the documents in the original list."""
class InputSchema(TypedDict):
"""Input schema for the map-reduce extraction chain.
Defines the expected input format when invoking the extraction chain.
"""
documents: list[Document]
"""List of documents to process."""
class OutputSchema(TypedDict):
"""Output schema for the map-reduce extraction chain.
Defines the format of the final result returned by the chain.
"""
map_results: list[ExtractionResult]
"""List of individual extraction results from the map phase."""
result: Any
"""Final combined result from all documents."""
class MapReduceNodeUpdate(TypedDict):
"""Update returned by map-reduce nodes."""
map_results: NotRequired[list[ExtractionResult]]
"""Updated results after map phase."""
result: NotRequired[Any]
"""Final result after reduce phase."""
class _MapReduceExtractor(Generic[ContextT]):
"""Map-reduce extraction implementation using LangGraph Send API.
This implementation uses a language model to process documents through up
to two phases:
1. **Map Phase**: Each document is processed independently by the LLM using
the configured map_prompt to generate individual extraction results.
2. **Reduce Phase (Optional)**: Individual results can optionally be
combined using either:
- The default LLM-based reducer with the configured reduce_prompt
- A custom reducer function (which can be non-LLM based)
- Skipped entirely by setting reduce=None
The map phase processes documents in parallel for efficiency, making this approach
well-suited for large document collections. The reduce phase is flexible and can be
customized or omitted based on your specific requirements.
"""
def __init__(
self,
model: Union[BaseChatModel, str],
*,
map_prompt: Union[
str,
None,
Callable[
[MapState, Runtime[ContextT]],
list[MessageLikeRepresentation],
],
] = None,
reduce_prompt: Union[
str,
None,
Callable[
[MapReduceState, Runtime[ContextT]],
list[MessageLikeRepresentation],
],
] = None,
reduce: Union[
Literal["default_reducer"],
None,
StateNode,
] = "default_reducer",
context_schema: type[ContextT] | None = None,
response_format: Optional[type[BaseModel]] = None,
) -> None:
"""Initialize the MapReduceExtractor.
Args:
model: The language model, either a chat model instance
(e.g., `ChatAnthropic()`) or a string identifier
(e.g., `"anthropic:claude-sonnet-4-20250514"`).
map_prompt: Prompt for individual document processing. Can be:
- str: A system message string
- None: Use default system message
- Callable: A function that takes (state, runtime) and returns messages
reduce_prompt: Prompt for combining results. Can be:
- str: A system message string
- None: Use default system message
- Callable: A function that takes (state, runtime) and returns messages
reduce: Controls the reduce behavior. Can be:
- "default_reducer": Use the default LLM-based reduce step
- None: Skip the reduce step entirely
- Callable: Custom reduce function (sync or async)
context_schema: Optional context schema for the LangGraph runtime.
response_format: Optional pydantic BaseModel for structured output.
"""
if (reduce is None or callable(reduce)) and reduce_prompt is not None:
msg = (
"reduce_prompt must be None when reduce is None or a custom "
"callable. Custom reduce functions handle their own logic and "
"should not use reduce_prompt."
)
raise ValueError(msg)
self.response_format = response_format
if isinstance(model, str):
model = init_chat_model(model)
self.model = (
model.with_structured_output(response_format) if response_format else model
)
self.map_prompt = map_prompt
self.reduce_prompt = reduce_prompt
self.reduce = reduce
self.context_schema = context_schema
def _get_map_prompt(
self, state: MapState, runtime: Runtime[ContextT]
) -> list[MessageLikeRepresentation]:
"""Generate the LLM prompt for processing documents."""
documents = state["documents"]
user_content = "\n\n".join(format_document_xml(doc) for doc in documents)
default_system = (
"You are a helpful assistant that processes documents. "
"Please process the following documents and provide a result."
)
return resolve_prompt(
self.map_prompt,
state,
runtime,
user_content,
default_system,
)
async def _aget_map_prompt(
self, state: MapState, runtime: Runtime[ContextT]
) -> list[MessageLikeRepresentation]:
"""Generate the LLM prompt for processing documents in the map phase.
Async version.
"""
documents = state["documents"]
user_content = "\n\n".join(format_document_xml(doc) for doc in documents)
default_system = (
"You are a helpful assistant that processes documents. "
"Please process the following documents and provide a result."
)
return await aresolve_prompt(
self.map_prompt,
state,
runtime,
user_content,
default_system,
)
def _get_reduce_prompt(
self, state: MapReduceState, runtime: Runtime[ContextT]
) -> list[MessageLikeRepresentation]:
"""Generate the LLM prompt for combining individual results.
Combines map results in the reduce phase.
"""
map_results = state.get("map_results", [])
if not map_results:
msg = (
"Internal programming error: Results must exist when reducing. "
"This indicates that the reduce node was reached without "
"first processing the map nodes, which violates "
"the expected graph execution order."
)
raise AssertionError(msg)
results_text = "\n\n".join(
f"Result {i + 1} (from documents "
f"{', '.join(map(str, result['indexes']))}):\n{result['result']}"
for i, result in enumerate(map_results)
)
user_content = (
f"Please combine the following results into a single, "
f"comprehensive result:\n\n{results_text}"
)
default_system = (
"You are a helpful assistant that combines multiple results. "
"Given several individual results, create a single comprehensive "
"result that captures the key information from all inputs while "
"maintaining conciseness and coherence."
)
return resolve_prompt(
self.reduce_prompt,
state,
runtime,
user_content,
default_system,
)
async def _aget_reduce_prompt(
self, state: MapReduceState, runtime: Runtime[ContextT]
) -> list[MessageLikeRepresentation]:
"""Generate the LLM prompt for combining individual results.
Async version of reduce phase.
"""
map_results = state.get("map_results", [])
if not map_results:
msg = (
"Internal programming error: Results must exist when reducing. "
"This indicates that the reduce node was reached without "
"first processing the map nodes, which violates "
"the expected graph execution order."
)
raise AssertionError(msg)
results_text = "\n\n".join(
f"Result {i + 1} (from documents "
f"{', '.join(map(str, result['indexes']))}):\n{result['result']}"
for i, result in enumerate(map_results)
)
user_content = (
f"Please combine the following results into a single, "
f"comprehensive result:\n\n{results_text}"
)
default_system = (
"You are a helpful assistant that combines multiple results. "
"Given several individual results, create a single comprehensive "
"result that captures the key information from all inputs while "
"maintaining conciseness and coherence."
)
return await aresolve_prompt(
self.reduce_prompt,
state,
runtime,
user_content,
default_system,
)
def create_map_node(self) -> RunnableCallable:
"""Create a LangGraph node that processes individual documents using the LLM."""
def _map_node(
state: MapState, runtime: Runtime[ContextT], config: RunnableConfig
) -> dict[str, list[ExtractionResult]]:
prompt = self._get_map_prompt(state, runtime)
response = cast("AIMessage", self.model.invoke(prompt, config=config))
result = response if self.response_format else response.text()
extraction_result: ExtractionResult = {
"indexes": state["indexes"],
"result": result,
}
return {"map_results": [extraction_result]}
async def _amap_node(
state: MapState,
runtime: Runtime[ContextT],
config: RunnableConfig,
) -> dict[str, list[ExtractionResult]]:
prompt = await self._aget_map_prompt(state, runtime)
response = cast(
"AIMessage", await self.model.ainvoke(prompt, config=config)
)
result = response if self.response_format else response.text()
extraction_result: ExtractionResult = {
"indexes": state["indexes"],
"result": result,
}
return {"map_results": [extraction_result]}
return RunnableCallable(
_map_node,
_amap_node,
trace=False,
)
def create_reduce_node(self) -> RunnableCallable:
"""Create a LangGraph node that combines individual results using the LLM."""
def _reduce_node(
state: MapReduceState, runtime: Runtime[ContextT], config: RunnableConfig
) -> MapReduceNodeUpdate:
prompt = self._get_reduce_prompt(state, runtime)
response = cast("AIMessage", self.model.invoke(prompt, config=config))
result = response if self.response_format else response.text()
return {"result": result}
async def _areduce_node(
state: MapReduceState,
runtime: Runtime[ContextT],
config: RunnableConfig,
) -> MapReduceNodeUpdate:
prompt = await self._aget_reduce_prompt(state, runtime)
response = cast(
"AIMessage", await self.model.ainvoke(prompt, config=config)
)
result = response if self.response_format else response.text()
return {"result": result}
return RunnableCallable(
_reduce_node,
_areduce_node,
trace=False,
)
def continue_to_map(self, state: MapReduceState) -> list[Send]:
"""Generate Send objects for parallel map operations."""
return [
Send("map_process", {"documents": [doc], "indexes": [i]})
for i, doc in enumerate(state["documents"])
]
def build(
self,
) -> StateGraph[MapReduceState, ContextT, InputSchema, OutputSchema]:
"""Build and compile the LangGraph for map-reduce summarization."""
builder = StateGraph(
MapReduceState,
context_schema=self.context_schema,
input_schema=InputSchema,
output_schema=OutputSchema,
)
builder.add_node("map_process", self.create_map_node())
builder.add_edge(START, "continue_to_map")
# add_conditional_edges doesn't explicitly type Send
builder.add_conditional_edges(
"continue_to_map",
self.continue_to_map, # type: ignore[arg-type]
["map_process"],
)
if self.reduce is None:
builder.add_edge("map_process", END)
elif self.reduce == "default_reducer":
builder.add_node("reduce_process", self.create_reduce_node())
builder.add_edge("map_process", "reduce_process")
builder.add_edge("reduce_process", END)
else:
reduce_node = cast("StateNode", self.reduce)
# The type is ignored here. Requires parameterizing with generics.
builder.add_node("reduce_process", reduce_node) # type: ignore[arg-type]
builder.add_edge("map_process", "reduce_process")
builder.add_edge("reduce_process", END)
return builder
def create_map_reduce_chain(
model: Union[BaseChatModel, str],
*,
map_prompt: Union[
str,
None,
Callable[[MapState, Runtime[ContextT]], list[MessageLikeRepresentation]],
] = None,
reduce_prompt: Union[
str,
None,
Callable[[MapReduceState, Runtime[ContextT]], list[MessageLikeRepresentation]],
] = None,
reduce: Union[
Literal["default_reducer"],
None,
StateNode,
] = "default_reducer",
context_schema: type[ContextT] | None = None,
response_format: Optional[type[BaseModel]] = None,
) -> StateGraph[MapReduceState, ContextT, InputSchema, OutputSchema]:
"""Create a map-reduce document extraction chain.
This implementation uses a language model to extract information from documents
through a flexible approach that efficiently handles large document collections
by processing documents in parallel.
**Processing Flow:**
1. **Map Phase**: Each document is independently processed by the LLM
using the map_prompt to extract relevant information and generate
individual results.
2. **Reduce Phase (Optional)**: Individual extraction results can
optionally be combined using:
- The default LLM-based reducer with reduce_prompt (default behavior)
- A custom reducer function (can be non-LLM based)
- Skipped entirely by setting reduce=None
3. **Output**: Returns the individual map results and optionally the final
combined result.
Example:
>>> from langchain_anthropic import ChatAnthropic
>>> from langchain_core.documents import Document
>>>
>>> model = ChatAnthropic(
... model="claude-sonnet-4-20250514",
... temperature=0,
... max_tokens=62_000,
... timeout=None,
... max_retries=2,
... )
>>> builder = create_map_reduce_chain(model)
>>> chain = builder.compile()
>>> docs = [
... Document(page_content="First document content..."),
... Document(page_content="Second document content..."),
... Document(page_content="Third document content..."),
... ]
>>> result = chain.invoke({"documents": docs})
>>> print(result["result"])
Example with string model:
>>> builder = create_map_reduce_chain("anthropic:claude-sonnet-4-20250514")
>>> chain = builder.compile()
>>> result = chain.invoke({"documents": docs})
>>> print(result["result"])
Example with structured output:
```python
from pydantic import BaseModel
class ExtractionModel(BaseModel):
title: str
key_points: list[str]
conclusion: str
builder = create_map_reduce_chain(
model,
response_format=ExtractionModel
)
chain = builder.compile()
result = chain.invoke({"documents": docs})
print(result["result"].title) # Access structured fields
```
Example skipping the reduce phase:
```python
# Only perform map phase, skip combining results
builder = create_map_reduce_chain(model, reduce=None)
chain = builder.compile()
result = chain.invoke({"documents": docs})
# result["result"] will be None, only map_results are available
for map_result in result["map_results"]:
print(f"Document {map_result['indexes'][0]}: {map_result['result']}")
```
Example with custom reducer:
```python
def custom_aggregator(state, runtime):
# Custom non-LLM based reduction logic
map_results = state["map_results"]
combined_text = " | ".join(r["result"] for r in map_results)
word_count = len(combined_text.split())
return {
"result": f"Combined {len(map_results)} results with "
f"{word_count} total words"
}
builder = create_map_reduce_chain(model, reduce=custom_aggregator)
chain = builder.compile()
result = chain.invoke({"documents": docs})
print(result["result"]) # Custom aggregated result
```
Args:
model: The language model, either a chat model instance
(e.g., `ChatAnthropic()`) or a string identifier
(e.g., `"anthropic:claude-sonnet-4-20250514"`).
map_prompt: Prompt for individual document processing. Can be:
- str: A system message string
- None: Use default system message
- Callable: A function that takes (state, runtime) and returns messages
reduce_prompt: Prompt for combining results. Can be:
- str: A system message string
- None: Use default system message
- Callable: A function that takes (state, runtime) and returns messages
reduce: Controls the reduce behavior. Can be:
- "default_reducer": Use the default LLM-based reduce step
- None: Skip the reduce step entirely
- Callable: Custom reduce function (sync or async)
context_schema: Optional context schema for the LangGraph runtime.
response_format: Optional pydantic BaseModel for structured output.
Returns:
A LangGraph that can be invoked with documents to get map-reduce
extraction results.
Note:
This implementation is well-suited for large document collections as it
processes documents in parallel during the map phase. The Send API enables
efficient parallelization while maintaining clean state management.
"""
extractor = _MapReduceExtractor(
model,
map_prompt=map_prompt,
reduce_prompt=reduce_prompt,
reduce=reduce,
context_schema=context_schema,
response_format=response_format,
)
return extractor.build()
__all__ = ["create_map_reduce_chain"]


@@ -0,0 +1,473 @@
"""Stuff documents chain for processing documents by putting them all in context."""
from __future__ import annotations
from typing import (
TYPE_CHECKING,
Any,
Callable,
Generic,
Optional,
Union,
cast,
)
# Used not only for type checking, but is fetched at runtime by Pydantic.
from langchain_core.documents import Document as Document # noqa: TC002
from langgraph.graph import START, StateGraph
from typing_extensions import NotRequired, TypedDict
from langchain._internal._documents import format_document_xml
from langchain._internal._prompts import aresolve_prompt, resolve_prompt
from langchain._internal._typing import ContextT
from langchain._internal._utils import RunnableCallable
from langchain.chat_models import init_chat_model
if TYPE_CHECKING:
from langchain_core.language_models.chat_models import BaseChatModel
# Used for type checking, but IDEs may not recognize it inside the cast.
from langchain_core.messages import AIMessage as AIMessage
from langchain_core.messages import MessageLikeRepresentation
from langchain_core.runnables import RunnableConfig
from langgraph.runtime import Runtime
from pydantic import BaseModel
# Default system prompts
DEFAULT_INIT_PROMPT = (
"You are a helpful assistant that summarizes text. "
"Please provide a concise summary of the documents "
"provided by the user."
)
DEFAULT_STRUCTURED_INIT_PROMPT = (
"You are a helpful assistant that extracts structured information from documents. "
"Use the provided content and optional question to generate your output, formatted "
"according to the predefined schema."
)
DEFAULT_REFINE_PROMPT = (
"You are a helpful assistant that refines summaries. "
"Given an existing summary and new context, produce a refined summary "
"that incorporates the new information while maintaining conciseness."
)
DEFAULT_STRUCTURED_REFINE_PROMPT = (
"You are a helpful assistant refining structured information extracted "
"from documents. "
"You are given a previous result and new document context. "
"Update the output to reflect the new context, staying consistent with "
"the expected schema."
)
def _format_documents_content(documents: list[Document]) -> str:
"""Format documents into content string.
Args:
documents: List of documents to format.
Returns:
Formatted document content string.
"""
return "\n\n".join(format_document_xml(doc) for doc in documents)
class ExtractionState(TypedDict):
"""State for extraction chain.
This state tracks the extraction process where documents
are processed in batch, with the result being refined if needed.
"""
documents: list[Document]
"""List of documents to process."""
result: NotRequired[Any]
"""Current result, refined with each document."""
class InputSchema(TypedDict):
"""Input schema for the extraction chain.
Defines the expected input format when invoking the extraction chain.
"""
documents: list[Document]
"""List of documents to process."""
result: NotRequired[Any]
"""Existing result to refine (optional)."""
class OutputSchema(TypedDict):
"""Output schema for the extraction chain.
Defines the format of the final result returned by the chain.
"""
result: Any
"""Result from processing the documents."""
class ExtractionNodeUpdate(TypedDict):
"""Update returned by processing nodes."""
result: NotRequired[Any]
"""Updated result after processing a document."""
class _Extractor(Generic[ContextT]):
"""Stuff documents chain implementation.
This chain works by putting all the documents in the batch into the context
window of the language model. It processes all documents together in a single
request for extracting information or summaries. Can refine existing results
when provided.
Important: This chain does not attempt to control for the size of the context
window of the LLM. Ensure your documents fit within the model's context limits.
"""
def __init__(
self,
model: Union[BaseChatModel, str],
*,
prompt: Union[
str,
None,
Callable[
[ExtractionState, Runtime[ContextT]],
list[MessageLikeRepresentation],
],
] = None,
refine_prompt: Union[
str,
None,
Callable[
[ExtractionState, Runtime[ContextT]],
list[MessageLikeRepresentation],
],
] = None,
context_schema: type[ContextT] | None = None,
response_format: Optional[type[BaseModel]] = None,
) -> None:
"""Initialize the Extractor.
Args:
model: The language model, either a chat model instance
(e.g., `ChatAnthropic()`) or a string identifier
(e.g., `"anthropic:claude-sonnet-4-20250514"`).
prompt: Prompt for initial processing. Can be:
- str: A system message string
- None: Use default system message
- Callable: A function that takes (state, runtime) and returns messages
refine_prompt: Prompt for refinement steps. Can be:
- str: A system message string
- None: Use default system message
- Callable: A function that takes (state, runtime) and returns messages
context_schema: Optional context schema for the LangGraph runtime.
response_format: Optional pydantic BaseModel for structured output.
"""
self.response_format = response_format
if isinstance(model, str):
model = init_chat_model(model)
self.model = (
model.with_structured_output(response_format) if response_format else model
)
self.initial_prompt = prompt
self.refine_prompt = refine_prompt
self.context_schema = context_schema
def _get_initial_prompt(
self, state: ExtractionState, runtime: Runtime[ContextT]
) -> list[MessageLikeRepresentation]:
"""Generate the initial extraction prompt."""
user_content = _format_documents_content(state["documents"])
# Choose default prompt based on structured output format
default_prompt = (
DEFAULT_STRUCTURED_INIT_PROMPT
if self.response_format
else DEFAULT_INIT_PROMPT
)
return resolve_prompt(
self.initial_prompt,
state,
runtime,
user_content,
default_prompt,
)
async def _aget_initial_prompt(
self, state: ExtractionState, runtime: Runtime[ContextT]
) -> list[MessageLikeRepresentation]:
"""Generate the initial extraction prompt (async version)."""
user_content = _format_documents_content(state["documents"])
# Choose default prompt based on structured output format
default_prompt = (
DEFAULT_STRUCTURED_INIT_PROMPT
if self.response_format
else DEFAULT_INIT_PROMPT
)
return await aresolve_prompt(
self.initial_prompt,
state,
runtime,
user_content,
default_prompt,
)
def _get_refine_prompt(
self, state: ExtractionState, runtime: Runtime[ContextT]
) -> list[MessageLikeRepresentation]:
"""Generate the refinement prompt."""
# Result should be guaranteed to exist at refinement stage
if "result" not in state or state["result"] == "":
msg = (
"Internal programming error: Result must exist when refining. "
"This indicates that the refinement node was reached without "
"first processing the initial result node, which violates "
"the expected graph execution order."
)
raise AssertionError(msg)
new_context = _format_documents_content(state["documents"])
user_content = (
f"Previous result:\n{state['result']}\n\n"
f"New context:\n{new_context}\n\n"
f"Please provide a refined result."
)
# Choose default prompt based on structured output format
default_prompt = (
DEFAULT_STRUCTURED_REFINE_PROMPT
if self.response_format
else DEFAULT_REFINE_PROMPT
)
return resolve_prompt(
self.refine_prompt,
state,
runtime,
user_content,
default_prompt,
)
async def _aget_refine_prompt(
self, state: ExtractionState, runtime: Runtime[ContextT]
) -> list[MessageLikeRepresentation]:
"""Generate the refinement prompt (async version)."""
# Result should be guaranteed to exist at refinement stage
if "result" not in state or state["result"] == "":
msg = (
"Internal programming error: Result must exist when refining. "
"This indicates that the refinement node was reached without "
"first processing the initial result node, which violates "
"the expected graph execution order."
)
raise AssertionError(msg)
new_context = _format_documents_content(state["documents"])
user_content = (
f"Previous result:\n{state['result']}\n\n"
f"New context:\n{new_context}\n\n"
f"Please provide a refined result."
)
# Choose default prompt based on structured output format
default_prompt = (
DEFAULT_STRUCTURED_REFINE_PROMPT
if self.response_format
else DEFAULT_REFINE_PROMPT
)
return await aresolve_prompt(
self.refine_prompt,
state,
runtime,
user_content,
default_prompt,
)
def create_document_processor_node(self) -> RunnableCallable:
"""Create the main document processing node.
The node handles both initial processing and refinement of results.
Refinement is done by providing the existing result and new context.
If the workflow is run with a checkpointer enabled, the result will be
persisted and available for a given thread id.
"""
def _process_node(
state: ExtractionState, runtime: Runtime[ContextT], config: RunnableConfig
) -> ExtractionNodeUpdate:
# Handle empty document list
if not state["documents"]:
return {}
# Determine if this is initial processing or refinement
if "result" not in state or state["result"] == "":
# Initial processing
prompt = self._get_initial_prompt(state, runtime)
response = cast("AIMessage", self.model.invoke(prompt, config=config))
result = response if self.response_format else response.text()
return {"result": result}
# Refinement
prompt = self._get_refine_prompt(state, runtime)
response = cast("AIMessage", self.model.invoke(prompt, config=config))
result = response if self.response_format else response.text()
return {"result": result}
async def _aprocess_node(
state: ExtractionState,
runtime: Runtime[ContextT],
config: RunnableConfig,
) -> ExtractionNodeUpdate:
# Handle empty document list
if not state["documents"]:
return {}
# Determine if this is initial processing or refinement
if "result" not in state or state["result"] == "":
# Initial processing
prompt = await self._aget_initial_prompt(state, runtime)
response = cast(
"AIMessage", await self.model.ainvoke(prompt, config=config)
)
result = response if self.response_format else response.text()
return {"result": result}
# Refinement
prompt = await self._aget_refine_prompt(state, runtime)
response = cast(
"AIMessage", await self.model.ainvoke(prompt, config=config)
)
result = response if self.response_format else response.text()
return {"result": result}
return RunnableCallable(
_process_node,
_aprocess_node,
trace=False,
)
def build(
self,
) -> StateGraph[ExtractionState, ContextT, InputSchema, OutputSchema]:
"""Build and compile the LangGraph for batch document extraction."""
builder = StateGraph(
ExtractionState,
context_schema=self.context_schema,
input_schema=InputSchema,
output_schema=OutputSchema,
)
builder.add_edge(START, "process")
builder.add_node("process", self.create_document_processor_node())
return builder
def create_stuff_documents_chain(
model: Union[BaseChatModel, str],
*,
prompt: Union[
str,
None,
Callable[[ExtractionState, Runtime[ContextT]], list[MessageLikeRepresentation]],
] = None,
refine_prompt: Union[
str,
None,
Callable[[ExtractionState, Runtime[ContextT]], list[MessageLikeRepresentation]],
] = None,
context_schema: type[ContextT] | None = None,
response_format: Optional[type[BaseModel]] = None,
) -> StateGraph[ExtractionState, ContextT, InputSchema, OutputSchema]:
"""Create a stuff documents chain for processing documents.
This chain works by putting all the documents in the batch into the context
window of the language model. It processes all documents together in a single
request for extracting information or summaries. Can refine existing results
when provided. The default prompts are optimized for summarization tasks, but
can be customized for other extraction tasks via the prompt parameters or
response_format.
Strategy:
1. Put all documents into the context window
2. Process all documents together in a single request
3. If an existing result is provided, refine it with all documents at once
4. Return the result
Important:
This chain does not attempt to control for the size of the context
window of the LLM. Ensure your documents fit within the model's context limits.
Example:
```python
from langchain.chat_models import init_chat_model
from langchain_core.documents import Document
model = init_chat_model("anthropic:claude-sonnet-4-20250514")
builder = create_stuff_documents_chain(model)
chain = builder.compile()
docs = [
Document(page_content="First document content..."),
Document(page_content="Second document content..."),
Document(page_content="Third document content..."),
]
result = chain.invoke({"documents": docs})
print(result["result"])
# Structured summary/extraction by passing a schema
from pydantic import BaseModel
class Summary(BaseModel):
title: str
key_points: list[str]
builder = create_stuff_documents_chain(model, response_format=Summary)
chain = builder.compile()
result = chain.invoke({"documents": docs})
print(result["result"].title) # Access structured fields
```
Args:
model: The language model for document processing.
prompt: Prompt for initial processing. Can be:
- str: A system message string
- None: Use default system message
- Callable: A function that takes (state, runtime) and returns messages
refine_prompt: Prompt for refinement steps. Can be:
- str: A system message string
- None: Use default system message
- Callable: A function that takes (state, runtime) and returns messages
context_schema: Optional context schema for the LangGraph runtime.
response_format: Optional pydantic BaseModel for structured output.
Returns:
A LangGraph that can be invoked with documents to extract information.
Note:
This is a "stuff" documents chain that puts all documents into the context
window and processes them together. It supports refining existing results.
Default prompts are optimized for summarization but can be customized for
other tasks. Important: Does not control for context window size.
"""
extractor = _Extractor(
model,
prompt=prompt,
refine_prompt=refine_prompt,
context_schema=context_schema,
response_format=response_format,
)
return extractor.build()
__all__ = ["create_stuff_documents_chain"]


@@ -0,0 +1,5 @@
from langchain_core.documents import Document
__all__ = [
"Document",
]


@@ -9,7 +9,7 @@ requires-python = ">=3.9, <4.0"
dependencies = [
"langchain-core<1.0.0,>=0.3.66",
"langchain-text-splitters<1.0.0,>=0.3.8",
"langgraph>=0.5.4",
"langgraph>=0.6.0",
"pydantic>=2.7.4",
]
@@ -123,6 +123,7 @@ ignore = [
"UP007", # pyupgrade: non-pep604-annotation-union
"PLC0415", # Imports should be at the top. Not always desirable
"PLR0913", # Too many arguments in function definition
"PLC0414", # Inconsistent with how type checkers expect to be notified of intentional re-exports
]
unfixable = ["B028"] # People should intentionally tune the stacklevel


@@ -1667,7 +1667,7 @@ requires-dist = [
{ name = "langchain-text-splitters", editable = "../text-splitters" },
{ name = "langchain-together", marker = "extra == 'together'" },
{ name = "langchain-xai", marker = "extra == 'xai'" },
{ name = "langgraph", specifier = ">=0.5.4" },
{ name = "langgraph", specifier = ">=0.6.0" },
{ name = "pydantic", specifier = ">=2.7.4" },
]
provides-extras = ["anthropic", "openai", "azure-ai", "google-vertexai", "google-genai", "fireworks", "ollama", "together", "mistralai", "huggingface", "groq", "aws", "deepseek", "xai", "perplexity"]
@@ -1757,7 +1757,7 @@ wheels = [
[[package]]
name = "langchain-core"
version = "0.3.70"
version = "0.3.72"
source = { editable = "../core" }
dependencies = [
{ name = "jsonpatch" },
@@ -2133,7 +2133,7 @@ wheels = [
[[package]]
name = "langgraph"
version = "0.5.4"
version = "0.6.1"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "langchain-core" },
@@ -2143,9 +2143,9 @@ dependencies = [
{ name = "pydantic" },
{ name = "xxhash" },
]
sdist = { url = "https://files.pythonhosted.org/packages/99/26/f01ae40ea26f8c723b6ec186869c80cc04de801630d99943018428b46105/langgraph-0.5.4.tar.gz", hash = "sha256:ab8f6b7b9c50fd2ae35a2efb072fbbfe79500dfc18071ac4ba6f5de5fa181931", size = 443149, upload-time = "2025-07-21T18:20:55.63Z" }
sdist = { url = "https://files.pythonhosted.org/packages/6e/12/4a30f766de571bfc319e70c2c0f4d050c0576e15c2249dc75ad122706a5d/langgraph-0.6.1.tar.gz", hash = "sha256:e4399ac5ad0b70f58fa28d6fe05a41b84c15959f270d6d1a86edab4e92ae148b", size = 449723, upload-time = "2025-07-29T20:45:28.438Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/0d/82/15184e953234877107bad182b79c9111cb6ce6a79a97fdf36ebcaa11c0d0/langgraph-0.5.4-py3-none-any.whl", hash = "sha256:7122840225623e081be24ac30a691a24e5dac4c0361f593208f912838192d7f6", size = 143942, upload-time = "2025-07-21T18:20:54.442Z" },
{ url = "https://files.pythonhosted.org/packages/6e/90/e37e2bc19bee81f859bfd61526d977f54ec5d8e036fa091f2cfe4f19560b/langgraph-0.6.1-py3-none-any.whl", hash = "sha256:2736027faeb6cd5c0f1ab51a5345594cfcb5eb5beeb5ac1799a58fcecf4b4eae", size = 151874, upload-time = "2025-07-29T20:45:26.998Z" },
]
[[package]]
@@ -2163,28 +2163,28 @@ wheels = [
[[package]]
name = "langgraph-prebuilt"
version = "0.5.2"
version = "0.6.1"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "langchain-core" },
{ name = "langgraph-checkpoint" },
]
sdist = { url = "https://files.pythonhosted.org/packages/bb/11/98134c47832fbde0caf0e06f1a104577da9215c358d7854093c1d835b272/langgraph_prebuilt-0.5.2.tar.gz", hash = "sha256:2c900a5be0d6a93ea2521e0d931697cad2b646f1fcda7aa5c39d8d7539772465", size = 117808, upload-time = "2025-06-30T19:52:48.307Z" }
sdist = { url = "https://files.pythonhosted.org/packages/55/b6/d4f8e800bdfdd75486595203d5c622bba5f1098e4fd4220452c75568f2be/langgraph_prebuilt-0.6.1.tar.gz", hash = "sha256:574c409113e02d3c58157877c5ea638faa80647b259027647ab88830d7ecef00", size = 125057, upload-time = "2025-07-29T20:44:48.634Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/c3/64/6bc45ab9e0e1112698ebff579fe21f5606ea65cd08266995a357e312a4d2/langgraph_prebuilt-0.5.2-py3-none-any.whl", hash = "sha256:1f4cd55deca49dffc3e5127eec12fcd244fc381321002f728afa88642d5ec59d", size = 23776, upload-time = "2025-06-30T19:52:47.494Z" },
{ url = "https://files.pythonhosted.org/packages/a6/df/cb4d73e99719b7ca0d42503d39dec49c67225779c6a1e689954a5604dbe6/langgraph_prebuilt-0.6.1-py3-none-any.whl", hash = "sha256:a3a970451371ec66509c6969505286a5d92132af7062d0b2b6dab08c2e27b50f", size = 28866, upload-time = "2025-07-29T20:44:47.72Z" },
]
[[package]]
name = "langgraph-sdk"
version = "0.1.74"
version = "0.2.0"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "httpx" },
{ name = "orjson" },
]
sdist = { url = "https://files.pythonhosted.org/packages/6d/f7/3807b72988f7eef5e0eb41e7e695eca50f3ed31f7cab5602db3b651c85ff/langgraph_sdk-0.1.74.tar.gz", hash = "sha256:7450e0db5b226cc2e5328ca22c5968725873630ef47c4206a30707cb25dc3ad6", size = 72190, upload-time = "2025-07-21T16:36:50.032Z" }
sdist = { url = "https://files.pythonhosted.org/packages/2a/3e/3dc45dc7682c9940db9edaf8773d2e157397c5bd6881f6806808afd8731e/langgraph_sdk-0.2.0.tar.gz", hash = "sha256:cd8b5f6595e5571be5cbffd04cf936978ab8f5d1005517c99715947ef871e246", size = 72510, upload-time = "2025-07-22T17:31:06.745Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/1f/1a/3eacc4df8127781ee4b0b1e5cad7dbaf12510f58c42cbcb9d1e2dba2a164/langgraph_sdk-0.1.74-py3-none-any.whl", hash = "sha256:3a265c3757fe0048adad4391d10486db63ef7aa5a2cbd22da22d4503554cb890", size = 50254, upload-time = "2025-07-21T16:36:49.134Z" },
{ url = "https://files.pythonhosted.org/packages/a5/03/a8ab0e8ea74be6058cb48bb1d85485b5c65d6ea183e3ee1aa8ca1ac73b3e/langgraph_sdk-0.2.0-py3-none-any.whl", hash = "sha256:150722264f225c4d47bbe7394676be102fdbf04c4400a0dd1bd41a70c6430cc7", size = 50569, upload-time = "2025-07-22T17:31:04.582Z" },
]
[[package]]