Scaffolding

This commit is contained in:
Eugene Yurtsev 2025-07-30 15:26:41 -04:00
parent f624ad489a
commit 5e6c1e8380
12 changed files with 1400 additions and 12 deletions

View File

@ -0,0 +1,28 @@
"""Internal document utilities."""
from __future__ import annotations
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from langchain_core.documents import Document
def format_document_xml(doc: Document) -> str:
"""Format a single document with XML-style structure.
Args:
doc: Document to format
Returns:
Formatted document string with XML tags for id, content, and metadata
"""
id_str = f"<id>{doc.id}</id>" if doc.id is not None else "<id></id>"
metadata_str = ""
if doc.metadata:
metadata_items = [f"{k}: {v!s}" for k, v in doc.metadata.items()]
metadata_str = f"<metadata>{', '.join(metadata_items)}</metadata>"
return (
f"<document>{id_str}<content>{doc.page_content}</content>"
f"{metadata_str}</document>"
)
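
For orientation, a small sketch of the output this helper produces (the document values are illustrative):

```python
from langchain_core.documents import Document

from langchain._internal.documents import format_document_xml

# Illustrative document; both id and metadata are optional on Document.
doc = Document(
    id="doc-1",
    page_content="LangChain adds new document chains.",
    metadata={"source": "notes.md"},
)

print(format_document_xml(doc))
# Prints a single line:
# <document><id>doc-1</id><content>LangChain adds new document chains.</content><metadata>source: notes.md</metadata></document>
```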

View File

@ -0,0 +1,36 @@
"""Lazy import utilities."""
from importlib import import_module
from typing import Union
def import_attr(
attr_name: str,
module_name: Union[str, None],
package: Union[str, None],
) -> object:
"""Import an attribute from a module located in a package.
This utility function is used in custom __getattr__ methods within __init__.py
files to dynamically import attributes.
Args:
attr_name: The name of the attribute to import.
module_name: The name of the module to import from. If None, the attribute
is imported from the package itself.
package: The name of the package where the module is located.
"""
if module_name == "__module__" or module_name is None:
try:
result = import_module(f".{attr_name}", package=package)
except ModuleNotFoundError:
msg = f"module '{package!r}' has no attribute {attr_name!r}"
raise AttributeError(msg) from None
else:
try:
module = import_module(f".{module_name}", package=package)
except ModuleNotFoundError as err:
msg = f"module '{package!r}.{module_name!r}' not found ({err})"
raise ImportError(msg) from None
result = getattr(module, attr_name)
return result
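
A minimal sketch of the intended call site, a package `__init__.py` with a module-level `__getattr__`; the mapping entries and the module path for `import_attr` are assumptions for illustration:

```python
# Hypothetical __init__.py wiring import_attr into lazy imports.
# The import path below is assumed; only the call pattern matters here.
from langchain._internal._lazy_import import import_attr

_DYNAMIC_IMPORTS = {
    # attribute name -> submodule that defines it (illustrative entries)
    "create_map_reduce_chain": "map_reduce",
    "create_stuff_documents_chain": "stuff",
}


def __getattr__(attr_name: str) -> object:
    module_name = _DYNAMIC_IMPORTS.get(attr_name)
    if module_name is None:
        msg = f"module {__name__!r} has no attribute {attr_name!r}"
        raise AttributeError(msg)
    return import_attr(attr_name, module_name, __name__)
```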

View File

@ -0,0 +1,163 @@
"""Internal prompt resolution utilities.
This module provides utilities for resolving different types of prompt specifications
into standardized message formats for language models. It supports both synchronous
and asynchronous prompt resolution with automatic detection of callable types.
The module is designed to handle common prompt patterns across LangChain components,
particularly for summarization chains and other document processing workflows.
"""
from __future__ import annotations
import inspect
from typing import TYPE_CHECKING, Callable, Union
from typing_extensions import TypedDict
if TYPE_CHECKING:
from collections.abc import Awaitable
from langchain_core.messages import MessageLikeRepresentation
from langgraph.runtime import Runtime
from langgraph.typing import StateLike
def resolve_prompt(
prompt: Union[
str, None, Callable[[TypedDict, Runtime], list[MessageLikeRepresentation]]
],
state: StateLike,
runtime: Runtime,
default_user_content: str,
default_system_content: str,
) -> list[MessageLikeRepresentation]:
"""Resolve a prompt specification into a list of messages.
Handles prompt resolution across different strategies. Supports callable functions,
string system messages, and None for default behavior.
Args:
prompt: The prompt specification to resolve. Can be:
- Callable: Function taking (state, runtime) returning message list.
- str: A system message string.
- None: Use the provided default system message.
state: Current state, passed to callable prompts.
runtime: LangGraph runtime instance, passed to callable prompts.
default_user_content: User content to include (e.g., document text).
default_system_content: Default system message when prompt is None.
Returns:
List of message dictionaries for language models, typically containing
a system message and user message with content.
Raises:
TypeError: If prompt type is not str, None, or callable.
Example:
```python
def custom_prompt(state, runtime):
return [{"role": "system", "content": "Custom"}]
messages = resolve_prompt(custom_prompt, state, runtime, "content", "default")
messages = resolve_prompt("Custom system", state, runtime, "content", "default")
messages = resolve_prompt(None, state, runtime, "content", "Default")
```
Note:
Callable prompts have full control over message structure; the
default_user_content parameter is ignored for them. String/None prompts
create the standard system + user structure.
"""
if callable(prompt):
return prompt(state, runtime)
if isinstance(prompt, str):
system_msg = prompt
elif prompt is None:
system_msg = default_system_content
else:
msg = f"Invalid prompt type: {type(prompt)}. Expected str, None, or callable."
raise TypeError(msg)
return [
{"role": "system", "content": system_msg},
{"role": "user", "content": default_user_content},
]
async def aresolve_prompt(
prompt: Union[
str,
None,
Callable[[TypedDict, Runtime], list[MessageLikeRepresentation]],
Callable[[TypedDict, Runtime], Awaitable[list[MessageLikeRepresentation]]],
],
state: StateLike,
runtime: Runtime,
default_user_content: str,
default_system_content: str,
) -> list[MessageLikeRepresentation]:
"""Async version of resolve_prompt supporting both sync and async callables.
Handles prompt resolution across different strategies. Supports sync/async callable
functions, string system messages, and None for default behavior.
Args:
prompt: The prompt specification to resolve. Can be:
- Callable (sync): Function taking (state, runtime) returning message list.
- Callable (async): Async function taking (state, runtime) returning
awaitable message list.
- str: A system message string.
- None: Use the provided default system message.
state: Current state, passed to callable prompts.
runtime: LangGraph runtime instance, passed to callable prompts.
default_user_content: User content to include (e.g., document text).
default_system_content: Default system message when prompt is None.
Returns:
List of message dictionaries for language models, typically containing
a system message and user message with content.
Raises:
TypeError: If prompt type is not str, None, or callable.
Example:
```python
async def async_prompt(state, runtime):
return [{"role": "system", "content": "Async"}]
def sync_prompt(state, runtime):
return [{"role": "system", "content": "Sync"}]
messages = await aresolve_prompt(
async_prompt, state, runtime, "content", "default"
)
messages = await aresolve_prompt(
sync_prompt, state, runtime, "content", "default"
)
messages = await aresolve_prompt("Custom", state, runtime, "content", "default")
```
Note:
Callable prompts have full control over message structure; the
default_user_content parameter is ignored for them. Async callables are
detected automatically and awaited.
"""
if callable(prompt):
result = prompt(state, runtime)
# Check if the result is awaitable (async function)
if inspect.isawaitable(result):
return await result
return result
if isinstance(prompt, str):
system_msg = prompt
elif prompt is None:
system_msg = default_system_content
else:
msg = f"Invalid prompt type: {type(prompt)}. Expected str, None, or callable."
raise TypeError(msg)
return [
{"role": "system", "content": system_msg},
{"role": "user", "content": default_user_content},
]
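
As a quick illustration of the callable path, a sketch of a prompt function that reads from state; the dict-based state and `runtime=None` are stand-ins for the real LangGraph objects:

```python
from langchain._internal.prompts import resolve_prompt


def tone_aware_prompt(state, runtime):
    # Callable prompts build the full message list themselves.
    tone = state.get("tone", "neutral")
    return [
        {"role": "system", "content": f"Summarize the documents in a {tone} tone."},
        {"role": "user", "content": state["content"]},
    ]


messages = resolve_prompt(
    tone_aware_prompt,
    {"tone": "formal", "content": "Document text..."},  # stand-in state
    None,  # stand-in runtime; it is passed through to the callable unchanged
    "ignored for callable prompts",
    "ignored for callable prompts",
)
```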

View File

@ -0,0 +1,8 @@
# Re-exporting internal utilities from LangGraph for internal use in LangChain.
# TODO: Revisit this solution. Perhaps langgraph should expose a simple wrapper,
# e.g. create_node(sync, async), that returns a new node.
from langgraph._internal._runnable import RunnableCallable
__all__ = [
"RunnableCallable",
]

View File

@ -0,0 +1,4 @@
from langchain.chains.documents import (
create_map_reduce_chain,
create_stuff_documents_chain,
)

View File

@ -0,0 +1,18 @@
"""Document extraction chains.
This module provides different strategies for extracting information from collections
of documents using LangGraph and modern language models.
Available Strategies:
- Iterative: Processes documents sequentially, refining results at each step
- Map-Reduce: Processes documents in parallel, then combines results
- Recursive: Hierarchical processing for documents
"""
from langchain.chains.documents.map_reduce import create_map_reduce_chain
from langchain.chains.documents.stuff import create_stuff_documents_chain
__all__ = [
"create_map_reduce_chain",
"create_stuff_documents_chain",
]

View File

@ -0,0 +1,589 @@
"""Map-Reduce Extraction Implementation using LangGraph Send API."""
from __future__ import annotations
import operator
from typing import (
TYPE_CHECKING,
Annotated,
Any,
Callable,
Generic,
Literal,
Optional,
Union,
cast,
)
# Needs to be in global scope as the type annotation is used at runtime
from langchain_core.documents import Document as Document
# Needs to be in global scope as the type annotation is used at runtime
from langgraph.graph import END, START, StateGraph
from langgraph.types import Send
from langgraph.typing import ContextT
from typing_extensions import NotRequired, TypedDict
from langchain._internal.documents import format_document_xml
from langchain._internal.prompts import aresolve_prompt, resolve_prompt
from langchain._internal.utils import RunnableCallable
from langchain.chat_models import init_chat_model
if TYPE_CHECKING:
from collections.abc import Awaitable
from langchain_core.language_models.chat_models import BaseChatModel
from langchain_core.messages import AIMessage, MessageLikeRepresentation
from langchain_core.runnables import RunnableConfig
from langgraph.runtime import Runtime
from pydantic import BaseModel
class ExtractionResult(TypedDict):
"""Result from processing a document or group of documents."""
indexes: list[int]
"""Document indexes that contributed to this result."""
result: Any
"""Extracted result from the document(s)."""
class MapReduceState(TypedDict):
"""State for map-reduce extraction chain.
This state tracks the map-reduce process where documents are processed
in parallel during the map phase, then combined in the reduce phase.
"""
documents: list[Document]
"""List of documents to process."""
map_results: Annotated[list[ExtractionResult], operator.add]
"""Individual results from the map phase."""
result: NotRequired[Any]
"""Final combined result from the reduce phase if applicable."""
class MapState(TypedDict):
"""State for individual map operations."""
documents: list[Document]
"""List of documents to process in map phase."""
indexes: list[int]
"""List of indexes of the documents in the original list."""
class InputSchema(TypedDict):
"""Input schema for the map-reduce extraction chain.
Defines the expected input format when invoking the extraction chain.
"""
documents: list[Document]
"""List of documents to process."""
class OutputSchema(TypedDict):
"""Output schema for the map-reduce extraction chain.
Defines the format of the final result returned by the chain.
"""
map_results: list[ExtractionResult]
"""List of individual extraction results from the map phase."""
result: Any
"""Final combined result from all documents."""
class MapReduceNodeUpdate(TypedDict):
"""Update returned by map-reduce nodes."""
map_results: NotRequired[list[ExtractionResult]]
"""Updated results after map phase."""
result: NotRequired[Any]
"""Final result after reduce phase."""
class _MapReduceExtractor(Generic[ContextT]):
"""Map-reduce extraction implementation using LangGraph Send API.
This implementation uses a language model to process documents through up to two phases:
1. **Map Phase**: Each document is processed independently by the LLM using
the configured map_prompt to generate individual extraction results.
2. **Reduce Phase (Optional)**: Individual results can optionally be
combined using either:
- The default LLM-based reducer with the configured reduce_prompt
- A custom reducer function (which can be non-LLM based)
- Skipped entirely by setting reduce=None
The map phase processes documents in parallel for efficiency, making this approach
well-suited for large document collections. The reduce phase is flexible and can be
customized or omitted based on your specific requirements.
"""
def __init__(
self,
model: Union[BaseChatModel, str],
*,
map_prompt: Union[
str,
None,
Callable[
[MapState, Runtime[ContextT]],
list[MessageLikeRepresentation],
],
] = None,
reduce_prompt: Union[
str,
None,
Callable[
[MapReduceState, Runtime[ContextT]],
list[MessageLikeRepresentation],
],
] = None,
reduce: Union[
Literal["default_reducer"],
None,
Callable[
[MapReduceState, Runtime[ContextT]],
Any,
],
Callable[
[MapReduceState, Runtime[ContextT]],
Awaitable[Any],
],
] = "default_reducer",
context_schema: ContextT = None,
response_format: Optional[type[BaseModel]] = None,
) -> None:
"""Initialize the MapReduceExtractor.
Args:
model: The language model, either a chat model instance
(e.g., `ChatAnthropic()`) or a string identifier
(e.g., `"anthropic:claude-sonnet-4-20250514"`).
map_prompt: Prompt for individual document processing. Can be:
- str: A system message string
- None: Use default system message
- Callable: A function that takes (state, runtime) and returns messages
reduce_prompt: Prompt for combining results. Can be:
- str: A system message string
- None: Use default system message
- Callable: A function that takes (state, runtime) and returns messages
reduce: Controls the reduce behavior. Can be:
- "default_reducer": Use the default LLM-based reduce step
- None: Skip the reduce step entirely
- Callable: Custom reduce function (sync or async)
context_schema: Optional context schema for the LangGraph runtime.
response_format: Optional pydantic BaseModel for structured output.
"""
if reduce is None or callable(reduce):
if reduce_prompt is not None:
msg = (
"reduce_prompt must be None when reduce is None or a custom "
"callable. Custom reduce functions handle their own logic and "
"should not use reduce_prompt."
)
raise ValueError(msg)
self.response_format = response_format
if isinstance(model, str):
model = cast("BaseChatModel", init_chat_model(model))
self.model = (
model.with_structured_output(response_format) if response_format else model
)
self.map_prompt = map_prompt
self.reduce_prompt = reduce_prompt
self.reduce = reduce
self.context_schema = context_schema or {}
def _get_map_prompt(
self, state: MapState, runtime: Runtime
) -> list[MessageLikeRepresentation]:
"""Generate the LLM prompt for processing documents in the map
phase."""
documents = state["documents"]
user_content = "\n\n".join(format_document_xml(doc) for doc in documents)
default_system = (
"You are a helpful assistant that processes documents. "
"Please process the following documents and provide a result."
)
return resolve_prompt(
self.map_prompt,
state,
runtime,
user_content,
default_system,
)
async def _aget_map_prompt(
self, state: MapState, runtime: Runtime
) -> list[MessageLikeRepresentation]:
"""Generate the LLM prompt for processing documents in the map
phase (async)."""
documents = state["documents"]
user_content = "\n\n".join(format_document_xml(doc) for doc in documents)
default_system = (
"You are a helpful assistant that processes documents. "
"Please process the following documents and provide a result."
)
return await aresolve_prompt(
self.map_prompt,
state,
runtime,
user_content,
default_system,
)
def _get_reduce_prompt(
self, state: MapReduceState, runtime: Runtime
) -> list[MessageLikeRepresentation]:
"""Generate the LLM prompt for combining individual results in the
reduce phase."""
map_results = state.get("map_results", [])
if not map_results:
msg = (
"Internal programming error: Results must exist when reducing. "
"This indicates that the reduce node was reached without "
"first processing the map nodes, which violates "
"the expected graph execution order."
)
raise AssertionError(msg)
results_text = "\n\n".join(
f"Result {i + 1} (from documents "
f"{', '.join(map(str, result['indexes']))}):\n{result['result']}"
for i, result in enumerate(map_results)
)
user_content = (
f"Please combine the following results into a single, "
f"comprehensive result:\n\n{results_text}"
)
default_system = (
"You are a helpful assistant that combines multiple results. "
"Given several individual results, create a single comprehensive "
"result that captures the key information from all inputs while "
"maintaining conciseness and coherence."
)
return resolve_prompt(
self.reduce_prompt,
state,
runtime,
user_content,
default_system,
)
async def _aget_reduce_prompt(
self, state: MapReduceState, runtime: Runtime
) -> list[MessageLikeRepresentation]:
"""Generate the LLM prompt for combining individual results in the
reduce phase (async)."""
map_results = state.get("map_results", [])
if not map_results:
msg = (
"Internal programming error: Results must exist when reducing. "
"This indicates that the reduce node was reached without "
"first processing the map nodes, which violates "
"the expected graph execution order."
)
raise AssertionError(msg)
results_text = "\n\n".join(
f"Result {i + 1} (from documents "
f"{', '.join(map(str, result['indexes']))}):\n{result['result']}"
for i, result in enumerate(map_results)
)
user_content = (
f"Please combine the following results into a single, "
f"comprehensive result:\n\n{results_text}"
)
default_system = (
"You are a helpful assistant that combines multiple results. "
"Given several individual results, create a single comprehensive "
"result that captures the key information from all inputs while "
"maintaining conciseness and coherence."
)
return await aresolve_prompt(
self.reduce_prompt,
state,
runtime,
user_content,
default_system,
)
def create_map_node(self) -> RunnableCallable:
"""Create a LangGraph node that processes individual documents using the LLM."""
def _map_node(
state: MapState, runtime: Runtime, config: RunnableConfig
) -> dict[str, list[ExtractionResult]]:
prompt = self._get_map_prompt(state, runtime)
response = cast("AIMessage", self.model.invoke(prompt, config=config))
result = response if self.response_format else response.text()
extraction_result: ExtractionResult = {
"indexes": state["indexes"],
"result": result,
}
return {"map_results": [extraction_result]}
async def _amap_node(
state: MapState,
runtime: Runtime[ContextT],
config: RunnableConfig,
) -> dict[str, list[ExtractionResult]]:
prompt = await self._aget_map_prompt(state, runtime)
response = cast(
"AIMessage", await self.model.ainvoke(prompt, config=config)
)
result = response if self.response_format else response.text()
extraction_result: ExtractionResult = {
"indexes": state["indexes"],
"result": result,
}
return {"map_results": [extraction_result]}
return RunnableCallable(
_map_node,
_amap_node,
trace=False,
)
def create_reduce_node(self) -> RunnableCallable:
"""Create a LangGraph node that combines individual results using the LLM."""
def _reduce_node(
state: MapReduceState, runtime: Runtime, config: RunnableConfig
) -> MapReduceNodeUpdate:
prompt = self._get_reduce_prompt(state, runtime)
response = cast("AIMessage", self.model.invoke(prompt, config=config))
result = response if self.response_format else response.text()
return {"result": result}
async def _areduce_node(
state: MapReduceState,
runtime: Runtime[ContextT],
config: RunnableConfig,
) -> MapReduceNodeUpdate:
prompt = await self._aget_reduce_prompt(state, runtime)
response = cast(
"AIMessage", await self.model.ainvoke(prompt, config=config)
)
result = response if self.response_format else response.text()
return {"result": result}
return RunnableCallable(
_reduce_node,
_areduce_node,
trace=False,
)
def continue_to_map(self, state: MapReduceState) -> list[Send]:
"""Generate Send objects for parallel map operations."""
return [
Send("map_process", {"documents": [doc], "indexes": [i]})
for i, doc in enumerate(state["documents"])
]
def build(
self,
) -> StateGraph[MapReduceState, ContextT, InputSchema, OutputSchema]:
"""Build and compile the LangGraph for map-reduce summarization."""
builder = StateGraph(
MapReduceState,
context_schema=self.context_schema,
input_schema=InputSchema,
output_schema=OutputSchema,
)
builder.add_node("continue_to_map", lambda state: {})
builder.add_node("map_process", self.create_map_node())
builder.add_edge(START, "continue_to_map")
builder.add_conditional_edges(
"continue_to_map", self.continue_to_map, ["map_process"]
)
if self.reduce is None:
builder.add_edge("map_process", END)
elif self.reduce == "default_reducer":
builder.add_node("reduce_process", self.create_reduce_node())
builder.add_edge("map_process", "reduce_process")
builder.add_edge("reduce_process", END)
elif callable(self.reduce):
builder.add_node("reduce_process", self.reduce)
builder.add_edge("map_process", "reduce_process")
builder.add_edge("reduce_process", END)
else:
msg = f"Invalid reduce configuration: {self.reduce}"
raise ValueError(msg)
return builder
def create_map_reduce_chain(
model: Union[BaseChatModel, str],
*,
map_prompt: Union[
str,
None,
Callable[[MapState, Runtime], list[MessageLikeRepresentation]],
] = None,
reduce_prompt: Union[
str,
None,
Callable[[MapReduceState, Runtime], list[MessageLikeRepresentation]],
] = None,
reduce: Union[
Literal["default_reducer"],
None,
Callable[
[MapReduceState, Runtime[ContextT]],
Any,
],
Callable[
[MapReduceState, Runtime[ContextT]],
Awaitable[Any],
],
] = "default_reducer",
context_schema: ContextT = None,
response_format: Optional[type[BaseModel]] = None,
) -> StateGraph[MapReduceState, ContextT, InputSchema, OutputSchema]:
"""Create a map-reduce document extraction chain.
This implementation uses a language model to extract information from documents
through a flexible approach that efficiently handles large document collections
by processing documents in parallel.
**Processing Flow:**
1. **Map Phase**: Each document is independently processed by the LLM
using the map_prompt to extract relevant information and generate
individual results.
2. **Reduce Phase (Optional)**: Individual extraction results can
optionally be combined using:
- The default LLM-based reducer with reduce_prompt (default behavior)
- A custom reducer function (can be non-LLM based)
- Skipped entirely by setting reduce=None
3. **Output**: Returns the individual map results and optionally the final combined result.
Example:
>>> from langchain_anthropic import ChatAnthropic
>>> from langchain_core.documents import Document
>>>
>>> model = ChatAnthropic(
... model="claude-sonnet-4-20250514",
... temperature=0,
... max_tokens=62_000,
... timeout=None,
... max_retries=2,
... )
>>> builder = create_map_reduce_chain(model)
>>> chain = builder.compile()
>>> docs = [
... Document(page_content="First document content..."),
... Document(page_content="Second document content..."),
... Document(page_content="Third document content..."),
... ]
>>> result = chain.invoke({"documents": docs})
>>> print(result["result"])
Example with string model:
>>> builder = create_map_reduce_chain("anthropic:claude-sonnet-4-20250514")
>>> chain = builder.compile()
>>> result = chain.invoke({"documents": docs})
>>> print(result["result"])
Example with structured output:
```python
from pydantic import BaseModel
class ExtractionModel(BaseModel):
title: str
key_points: list[str]
conclusion: str
builder = create_map_reduce_chain(
model,
response_format=ExtractionModel
)
chain = builder.compile()
result = chain.invoke({"documents": docs})
print(result["result"].title) # Access structured fields
```
Example skipping the reduce phase:
```python
# Only perform map phase, skip combining results
builder = create_map_reduce_chain(model, reduce=None)
chain = builder.compile()
result = chain.invoke({"documents": docs})
# result["result"] will be None, only map_results are available
for map_result in result["map_results"]:
print(f"Document {map_result['indexes'][0]}: {map_result['result']}")
```
Example with custom reducer:
```python
def custom_aggregator(state, runtime):
# Custom non-LLM based reduction logic
map_results = state["map_results"]
combined_text = " | ".join(r["result"] for r in map_results)
word_count = len(combined_text.split())
return {
"result": f"Combined {len(map_results)} results with "
f"{word_count} total words"
}
builder = create_map_reduce_chain(model, reduce=custom_aggregator)
chain = builder.compile()
result = chain.invoke({"documents": docs})
print(result["result"]) # Custom aggregated result
```
Args:
model: The language model, either a chat model instance
(e.g., `ChatAnthropic()`) or a string identifier
(e.g., `"anthropic:claude-sonnet-4-20250514"`).
map_prompt: Prompt for individual document processing. Can be:
- str: A system message string
- None: Use default system message
- Callable: A function that takes (state, runtime) and returns messages
reduce_prompt: Prompt for combining results. Can be:
- str: A system message string
- None: Use default system message
- Callable: A function that takes (state, runtime) and returns messages
reduce: Controls the reduce behavior. Can be:
- "default_reducer": Use the default LLM-based reduce step
- None: Skip the reduce step entirely
- Callable: Custom reduce function (sync or async)
context_schema: Optional context schema for the LangGraph runtime.
response_format: Optional pydantic BaseModel for structured output.
Returns:
A LangGraph that can be invoked with documents to get map-reduce
extraction results.
Note:
This implementation is well-suited for large document collections as it
processes documents in parallel during the map phase. The Send API enables
efficient parallelization while maintaining clean state management.
"""
extractor = _MapReduceExtractor(
model,
map_prompt=map_prompt,
reduce_prompt=reduce_prompt,
reduce=reduce,
context_schema=context_schema,
response_format=response_format,
)
return extractor.build()
__all__ = ["create_map_reduce_chain"]

View File

@ -0,0 +1,537 @@
"""Stuff documents chain for processing documents by putting them all in context."""
from __future__ import annotations
from typing import (
TYPE_CHECKING,
Any,
Callable,
Generic,
Optional,
Union,
cast,
)
from langchain_core.documents import Document
from langgraph.graph import START, StateGraph
from langgraph.typing import ContextT
from typing_extensions import NotRequired, TypedDict
from langchain._internal.documents import format_document_xml
from langchain._internal.prompts import aresolve_prompt, resolve_prompt
from langchain._internal.utils import RunnableCallable
from langchain.chat_models import init_chat_model
if TYPE_CHECKING:
from langchain_core.language_models.chat_models import BaseChatModel
from langchain_core.messages import AIMessage, MessageLikeRepresentation
from langchain_core.runnables import RunnableConfig
from langgraph.runtime import Runtime
from pydantic import BaseModel
# Default system prompts
DEFAULT_INIT_PROMPT = (
"You are a helpful assistant that summarizes text. "
"Please provide a concise summary of the documents "
"provided by the user."
)
DEFAULT_STRUCTURED_INIT_PROMPT = (
"You are a helpful assistant that extracts structured information from documents. "
"Use the provided content and optional question to generate your output, formatted "
"according to the predefined schema."
)
DEFAULT_REFINE_PROMPT = (
"You are a helpful assistant that refines summaries. "
"Given an existing summary and new context, produce a refined summary "
"that incorporates the new information while maintaining conciseness."
)
DEFAULT_STRUCTURED_REFINE_PROMPT = (
"You are a helpful assistant refining structured information extracted "
"from documents. "
"You are given a previous result and new document context. "
"Update the output to reflect the new context, staying consistent with "
"the expected schema."
)
def _format_documents_content(documents: list[Document]) -> str:
"""Format documents into content string.
Args:
documents: List of documents to format.
Returns:
Formatted document content string.
"""
return "\n\n".join(format_document_xml(doc) for doc in documents)
class ExtractionState(TypedDict):
"""State for extraction chain.
This state tracks the extraction process where documents
are processed in batch, with the result being refined if needed.
"""
documents: list[Document]
"""List of documents to process."""
question: NotRequired[str]
"""Optional question to ask about the documents."""
result: NotRequired[Any]
"""Current result, refined with each document."""
class InputSchema(TypedDict):
"""Input schema for the extraction chain.
Defines the expected input format when invoking the extraction chain.
"""
documents: list[Document]
"""List of documents to process."""
question: NotRequired[str]
"""Optional question to ask about the documents."""
result: NotRequired[Any]
"""Existing result to refine (optional)."""
class OutputSchema(TypedDict):
"""Output schema for the extraction chain.
Defines the format of the final result returned by the chain.
"""
result: Any
"""Result from processing the documents."""
class ExtractionNodeUpdate(TypedDict):
"""Update returned by processing nodes."""
result: NotRequired[Any]
"""Updated result after processing a document."""
class _Extractor(Generic[ContextT]):
"""Stuff documents chain implementation.
This chain works by putting all the documents in the batch into the context
window of the language model. It processes all documents together in a single
request for extracting information or summaries. Can refine existing results
when provided.
Important: This chain does not attempt to control for the size of the context
window of the LLM. Ensure your documents fit within the model's context limits.
"""
def __init__(
self,
model: Union[BaseChatModel, str],
*,
prompt: Union[
str,
None,
Callable[
[ExtractionState, Runtime[ContextT]],
list[MessageLikeRepresentation],
],
] = None,
refine_prompt: Union[
str,
None,
Callable[
[ExtractionState, Runtime[ContextT]],
list[MessageLikeRepresentation],
],
] = None,
context_schema: ContextT = None,
response_format: Optional[type[BaseModel]] = None,
) -> None:
"""Initialize the Extractor.
Args:
model: The language model, either a chat model instance
(e.g., `ChatAnthropic()`) or a string identifier
(e.g., `"anthropic:claude-sonnet-4-20250514"`).
prompt: Prompt for initial processing. Can be:
- str: A system message string
- None: Use default system message
- Callable: A function that takes (state, runtime) and returns messages
refine_prompt: Prompt for refinement steps. Can be:
- str: A system message string
- None: Use default system message
- Callable: A function that takes (state, runtime) and returns messages
context_schema: Optional context schema for the LangGraph runtime.
response_format: Optional pydantic BaseModel for structured output.
"""
self.response_format = response_format
if isinstance(model, str):
model = cast("BaseChatModel", init_chat_model(model))
self.model = (
model.with_structured_output(response_format) if response_format else model
)
self.initial_prompt = prompt
self.refine_prompt = refine_prompt
self.context_schema = context_schema or {}
def _get_initial_prompt(
self, state: ExtractionState, runtime: Runtime[ContextT]
) -> list[MessageLikeRepresentation]:
"""Generate the initial extraction prompt."""
# Validate prompt/question combination
question = state.get("question")
if question and self.initial_prompt is not None and not callable(self.initial_prompt):
msg = (
"When a question is provided, the prompt must be None or a callable. "
"String prompts cannot be used with questions since the question "
"becomes the system prompt."
)
raise ValueError(msg)
user_content = _format_documents_content(state["documents"])
# Use question as default system prompt if provided, otherwise use standard defaults
if question:
default_prompt = question
else:
# Choose default prompt based on structured output format
default_prompt = (
DEFAULT_STRUCTURED_INIT_PROMPT
if self.response_format
else DEFAULT_INIT_PROMPT
)
return resolve_prompt(
self.initial_prompt,
state,
runtime,
user_content,
default_prompt,
)
async def _aget_initial_prompt(
self, state: ExtractionState, runtime: Runtime[ContextT]
) -> list[MessageLikeRepresentation]:
"""Generate the initial extraction prompt (async version)."""
# Validate prompt/question combination
question = state.get("question")
if question and self.initial_prompt is not None and not callable(self.initial_prompt):
msg = (
"When a question is provided, the prompt must be None or a callable. "
"String prompts cannot be used with questions since the question "
"becomes the system prompt."
)
raise ValueError(msg)
user_content = _format_documents_content(state["documents"])
# Use question as default system prompt if provided, otherwise use standard defaults
if question:
default_prompt = question
else:
# Choose default prompt based on structured output format
default_prompt = (
DEFAULT_STRUCTURED_INIT_PROMPT
if self.response_format
else DEFAULT_INIT_PROMPT
)
return await aresolve_prompt(
self.initial_prompt,
state,
runtime,
user_content,
default_prompt,
)
def _get_refine_prompt(
self, state: ExtractionState, runtime: Runtime
) -> list[MessageLikeRepresentation]:
"""Generate the refinement prompt."""
# Result should be guaranteed to exist at refinement stage
if "result" not in state or state["result"] == "":
msg = (
"Internal programming error: Result must exist when refining. "
"This indicates that the refinement node was reached without "
"first processing the initial result node, which violates "
"the expected graph execution order."
)
raise AssertionError(msg)
# Validate prompt/question combination
question = state.get("question")
if question and self.refine_prompt is not None and not callable(self.refine_prompt):
msg = (
"When a question is provided, the refine_prompt must be None or a callable. "
"String prompts cannot be used with questions since the question "
"becomes the system prompt."
)
raise ValueError(msg)
new_context = _format_documents_content(state["documents"])
user_content = (
f"Previous result:\n{state['result']}\n\n"
f"New context:\n{new_context}\n\n"
f"Please provide a refined result."
)
# Use question as default system prompt if provided, otherwise use standard defaults
if question:
default_prompt = f"{question} (Please refine the previous result with the new context.)"
else:
# Choose default prompt based on structured output format
default_prompt = (
DEFAULT_STRUCTURED_REFINE_PROMPT
if self.response_format
else DEFAULT_REFINE_PROMPT
)
return resolve_prompt(
self.refine_prompt,
state,
runtime,
user_content,
default_prompt,
)
async def _aget_refine_prompt(
self, state: ExtractionState, runtime: Runtime
) -> list[MessageLikeRepresentation]:
"""Generate the refinement prompt (async version)."""
# Result should be guaranteed to exist at refinement stage
if "result" not in state or state["result"] == "":
msg = (
"Internal programming error: Result must exist when refining. "
"This indicates that the refinement node was reached without "
"first processing the initial result node, which violates "
"the expected graph execution order."
)
raise AssertionError(msg)
# Validate prompt/question combination
question = state.get("question")
if question and self.refine_prompt is not None and not callable(self.refine_prompt):
msg = (
"When a question is provided, the refine_prompt must be None or a callable. "
"String prompts cannot be used with questions since the question "
"becomes the system prompt."
)
raise ValueError(msg)
new_context = _format_documents_content(state["documents"])
user_content = (
f"Previous result:\n{state['result']}\n\n"
f"New context:\n{new_context}\n\n"
f"Please provide a refined result."
)
# Use question as default system prompt if provided, otherwise use standard defaults
if question:
default_prompt = f"{question} (Please refine the previous result with the new context.)"
else:
# Choose default prompt based on structured output format
default_prompt = (
DEFAULT_STRUCTURED_REFINE_PROMPT
if self.response_format
else DEFAULT_REFINE_PROMPT
)
return await aresolve_prompt(
self.refine_prompt,
state,
runtime,
user_content,
default_prompt,
)
def create_document_processor_node(self) -> RunnableCallable:
"""Create the main document processing node.
The node handles both initial processing and refinement of results.
Refinement is done by providing the existing result and new context.
If the workflow is run with a checkpointer enabled, the result will be
persisted and available for a given thread id.
"""
def _process_node(
state: ExtractionState, runtime: Runtime, config: RunnableConfig
) -> ExtractionNodeUpdate:
# Handle empty document list
if not state["documents"]:
return {}
# Determine if this is initial processing or refinement
if "result" not in state or state["result"] == "":
# Initial processing
prompt = self._get_initial_prompt(state, runtime)
response = cast("AIMessage", self.model.invoke(prompt, config=config))
result = response if self.response_format else response.text()
return {"result": result}
# Refinement
prompt = self._get_refine_prompt(state, runtime)
response = cast("AIMessage", self.model.invoke(prompt, config=config))
result = response if self.response_format else response.text()
return {"result": result}
async def _aprocess_node(
state: ExtractionState,
runtime: Runtime[ContextT],
config: RunnableConfig,
) -> ExtractionNodeUpdate:
# Handle empty document list
if not state["documents"]:
return {}
# Determine if this is initial processing or refinement
if "result" not in state or state["result"] == "":
# Initial processing
prompt = await self._aget_initial_prompt(state, runtime)
response = cast(
"AIMessage", await self.model.ainvoke(prompt, config=config)
)
result = response if self.response_format else response.text()
return {"result": result}
# Refinement
prompt = await self._aget_refine_prompt(state, runtime)
response = cast(
"AIMessage", await self.model.ainvoke(prompt, config=config)
)
result = response if self.response_format else response.text()
return {"result": result}
return RunnableCallable(
_process_node,
_aprocess_node,
trace=False,
)
def build(
self,
) -> StateGraph[ExtractionState, ContextT, InputSchema, OutputSchema]:
"""Build and compile the LangGraph for batch document extraction."""
builder = StateGraph(
ExtractionState,
context_schema=self.context_schema,
input_schema=InputSchema,
output_schema=OutputSchema,
)
builder.add_edge(START, "process")
builder.add_node("process", self.create_document_processor_node())
return builder
def create_stuff_documents_chain(
model: Union[BaseChatModel, str],
*,
prompt: Union[
str,
None,
Callable[[ExtractionState, Runtime], list[MessageLikeRepresentation]],
] = None,
refine_prompt: Union[
str,
None,
Callable[[ExtractionState, Runtime], list[MessageLikeRepresentation]],
] = None,
context_schema: ContextT = None,
response_format: Optional[type[BaseModel]] = None,
) -> StateGraph[ExtractionState, ContextT, InputSchema, OutputSchema]:
"""Create a stuff documents chain for processing documents.
This chain works by putting all the documents in the batch into the context
window of the language model. It processes all documents together in a single
request for extracting information or summaries. Can refine existing results
when provided. The default prompts are optimized for summarization tasks, but
can be customized for other extraction tasks via the prompt parameters or
response_format.
Strategy:
1. Put all documents into the context window
2. Process all documents together in a single request
3. If an existing result is provided, refine it with all documents at once
4. Return the result
Important:
This chain does not attempt to control for the size of the context
window of the LLM. Ensure your documents fit within the model's context limits.
Example:
```python
from langchain.chat_models import init_chat_model
from langchain_core.documents import Document
model = init_chat_model("anthropic:claude-sonnet-4-20250514")
builder = create_stuff_documents_chain(model)
chain = builder.compile()
docs = [
Document(page_content="First document content..."),
Document(page_content="Second document content..."),
Document(page_content="Third document content..."),
]
result = chain.invoke({"documents": docs})
print(result["result"])
# With a question
result = chain.invoke({
"documents": docs,
"question": "What are the main themes across these documents?"
})
print(result["result"])
# Structured summary/extraction by passing a schema
from pydantic import BaseModel
class Summary(BaseModel):
title: str
key_points: list[str]
builder = create_stuff_documents_chain(model, response_format=Summary)
chain = builder.compile()
result = chain.invoke({"documents": docs})
print(result["result"].title) # Access structured fields
```
Args:
model: The language model for document processing.
prompt: Prompt for initial processing. Can be:
- str: A system message string (cannot be used with questions)
- None: Use default system message (or question if provided)
- Callable: A function that takes (state, runtime) and returns messages
refine_prompt: Prompt for refinement steps. Can be:
- str: A system message string (cannot be used with questions)
- None: Use default system message (or question if provided)
- Callable: A function that takes (state, runtime) and returns messages
context_schema: Optional context schema for the LangGraph runtime.
response_format: Optional pydantic BaseModel for structured output.
Returns:
A LangGraph that can be invoked with documents to extract information.
Note:
This is a "stuff" documents chain that puts all documents into the context
window and processes them together. It supports refining existing results.
Default prompts are optimized for summarization but can be customized for
other tasks. Important: Does not control for context window size.
"""
extractor = _Extractor(
model,
prompt=prompt,
refine_prompt=refine_prompt,
context_schema=context_schema,
response_format=response_format,
)
return extractor.build()
__all__ = ["create_stuff_documents_chain"]

View File

@ -0,0 +1,5 @@
from langchain_core.documents import Document
__all__ = [
"Document",
]

View File

@ -9,7 +9,7 @@ requires-python = ">=3.9, <4.0"
dependencies = [
"langchain-core<1.0.0,>=0.3.66",
"langchain-text-splitters<1.0.0,>=0.3.8",
"langgraph>=0.5.4",
"langgraph>=0.6.0",
"pydantic>=2.7.4",
]

View File

@ -1667,7 +1667,7 @@ requires-dist = [
{ name = "langchain-text-splitters", editable = "../text-splitters" },
{ name = "langchain-together", marker = "extra == 'together'" },
{ name = "langchain-xai", marker = "extra == 'xai'" },
{ name = "langgraph", specifier = ">=0.5.4" },
{ name = "langgraph", specifier = ">=0.6.0" },
{ name = "pydantic", specifier = ">=2.7.4" },
]
provides-extras = ["anthropic", "openai", "azure-ai", "google-vertexai", "google-genai", "fireworks", "ollama", "together", "mistralai", "huggingface", "groq", "aws", "deepseek", "xai", "perplexity"]
@ -1757,7 +1757,7 @@ wheels = [
[[package]]
name = "langchain-core"
version = "0.3.70"
version = "0.3.72"
source = { editable = "../core" }
dependencies = [
{ name = "jsonpatch" },
@ -2133,7 +2133,7 @@ wheels = [
[[package]]
name = "langgraph"
version = "0.5.4"
version = "0.6.1"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "langchain-core" },
@ -2143,9 +2143,9 @@ dependencies = [
{ name = "pydantic" },
{ name = "xxhash" },
]
sdist = { url = "https://files.pythonhosted.org/packages/99/26/f01ae40ea26f8c723b6ec186869c80cc04de801630d99943018428b46105/langgraph-0.5.4.tar.gz", hash = "sha256:ab8f6b7b9c50fd2ae35a2efb072fbbfe79500dfc18071ac4ba6f5de5fa181931", size = 443149, upload-time = "2025-07-21T18:20:55.63Z" }
sdist = { url = "https://files.pythonhosted.org/packages/6e/12/4a30f766de571bfc319e70c2c0f4d050c0576e15c2249dc75ad122706a5d/langgraph-0.6.1.tar.gz", hash = "sha256:e4399ac5ad0b70f58fa28d6fe05a41b84c15959f270d6d1a86edab4e92ae148b", size = 449723, upload-time = "2025-07-29T20:45:28.438Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/0d/82/15184e953234877107bad182b79c9111cb6ce6a79a97fdf36ebcaa11c0d0/langgraph-0.5.4-py3-none-any.whl", hash = "sha256:7122840225623e081be24ac30a691a24e5dac4c0361f593208f912838192d7f6", size = 143942, upload-time = "2025-07-21T18:20:54.442Z" },
{ url = "https://files.pythonhosted.org/packages/6e/90/e37e2bc19bee81f859bfd61526d977f54ec5d8e036fa091f2cfe4f19560b/langgraph-0.6.1-py3-none-any.whl", hash = "sha256:2736027faeb6cd5c0f1ab51a5345594cfcb5eb5beeb5ac1799a58fcecf4b4eae", size = 151874, upload-time = "2025-07-29T20:45:26.998Z" },
]
[[package]]
@ -2163,28 +2163,28 @@ wheels = [
[[package]]
name = "langgraph-prebuilt"
version = "0.5.2"
version = "0.6.1"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "langchain-core" },
{ name = "langgraph-checkpoint" },
]
sdist = { url = "https://files.pythonhosted.org/packages/bb/11/98134c47832fbde0caf0e06f1a104577da9215c358d7854093c1d835b272/langgraph_prebuilt-0.5.2.tar.gz", hash = "sha256:2c900a5be0d6a93ea2521e0d931697cad2b646f1fcda7aa5c39d8d7539772465", size = 117808, upload-time = "2025-06-30T19:52:48.307Z" }
sdist = { url = "https://files.pythonhosted.org/packages/55/b6/d4f8e800bdfdd75486595203d5c622bba5f1098e4fd4220452c75568f2be/langgraph_prebuilt-0.6.1.tar.gz", hash = "sha256:574c409113e02d3c58157877c5ea638faa80647b259027647ab88830d7ecef00", size = 125057, upload-time = "2025-07-29T20:44:48.634Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/c3/64/6bc45ab9e0e1112698ebff579fe21f5606ea65cd08266995a357e312a4d2/langgraph_prebuilt-0.5.2-py3-none-any.whl", hash = "sha256:1f4cd55deca49dffc3e5127eec12fcd244fc381321002f728afa88642d5ec59d", size = 23776, upload-time = "2025-06-30T19:52:47.494Z" },
{ url = "https://files.pythonhosted.org/packages/a6/df/cb4d73e99719b7ca0d42503d39dec49c67225779c6a1e689954a5604dbe6/langgraph_prebuilt-0.6.1-py3-none-any.whl", hash = "sha256:a3a970451371ec66509c6969505286a5d92132af7062d0b2b6dab08c2e27b50f", size = 28866, upload-time = "2025-07-29T20:44:47.72Z" },
]
[[package]]
name = "langgraph-sdk"
version = "0.1.74"
version = "0.2.0"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "httpx" },
{ name = "orjson" },
]
sdist = { url = "https://files.pythonhosted.org/packages/6d/f7/3807b72988f7eef5e0eb41e7e695eca50f3ed31f7cab5602db3b651c85ff/langgraph_sdk-0.1.74.tar.gz", hash = "sha256:7450e0db5b226cc2e5328ca22c5968725873630ef47c4206a30707cb25dc3ad6", size = 72190, upload-time = "2025-07-21T16:36:50.032Z" }
sdist = { url = "https://files.pythonhosted.org/packages/2a/3e/3dc45dc7682c9940db9edaf8773d2e157397c5bd6881f6806808afd8731e/langgraph_sdk-0.2.0.tar.gz", hash = "sha256:cd8b5f6595e5571be5cbffd04cf936978ab8f5d1005517c99715947ef871e246", size = 72510, upload-time = "2025-07-22T17:31:06.745Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/1f/1a/3eacc4df8127781ee4b0b1e5cad7dbaf12510f58c42cbcb9d1e2dba2a164/langgraph_sdk-0.1.74-py3-none-any.whl", hash = "sha256:3a265c3757fe0048adad4391d10486db63ef7aa5a2cbd22da22d4503554cb890", size = 50254, upload-time = "2025-07-21T16:36:49.134Z" },
{ url = "https://files.pythonhosted.org/packages/a5/03/a8ab0e8ea74be6058cb48bb1d85485b5c65d6ea183e3ee1aa8ca1ac73b3e/langgraph_sdk-0.2.0-py3-none-any.whl", hash = "sha256:150722264f225c4d47bbe7394676be102fdbf04c4400a0dd1bd41a70c6430cc7", size = 50569, upload-time = "2025-07-22T17:31:04.582Z" },
]
[[package]]