Merge branch 'master' into wip-v0.4

Mason Daugherty 2025-08-07 15:33:12 -04:00 committed by GitHub
commit cbf4c0e565
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
23 changed files with 2340 additions and 391 deletions

View File

@@ -388,11 +388,12 @@ jobs:
- name: Test against ${{ matrix.partner }}
if: startsWith(inputs.working-directory, 'libs/core')
run: |
# Identify latest tag
# Identify latest tag, excluding pre-releases
LATEST_PACKAGE_TAG="$(
git ls-remote --tags origin "langchain-${{ matrix.partner }}*" \
| awk '{print $2}' \
| sed 's|refs/tags/||' \
| grep -Ev '==[^=]*(\.?dev[0-9]*|\.?rc[0-9]*)$' \
| sort -Vr \
| head -n 1
)"

File diff suppressed because it is too large.

View File

@@ -29,8 +29,8 @@
" Please refer to the instructions in:\n",
" [www.jaguardb.com](http://www.jaguardb.com)\n",
" For quick setup in docker environment:\n",
" docker pull jaguardb/jaguardb_with_http\n",
" docker run -d -p 8888:8888 -p 8080:8080 --name jaguardb_with_http jaguardb/jaguardb_with_http\n",
" docker pull jaguardb/jaguardb\n",
" docker run -d -p 8888:8888 -p 8080:8080 --name jaguardb jaguardb/jaguardb\n",
"\n",
"2. You must install the http client package for JaguarDB:\n",
" ```\n",

View File

@@ -666,6 +666,16 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
converted_generations.append(chat_gen)
else:
# Already a ChatGeneration or other expected type
if hasattr(gen, "message") and isinstance(gen.message, AIMessage):
# We zero out cost on cache hits
gen.message = gen.message.model_copy(
update={
"usage_metadata": {
**(gen.message.usage_metadata or {}),
"total_cost": 0,
}
}
)
converted_generations.append(gen)
return converted_generations

View File

@@ -458,3 +458,23 @@ def test_cleanup_serialized() -> None:
"name": "CustomChat",
"type": "constructor",
}
def test_token_costs_are_zeroed_out() -> None:
# We zero-out token costs for cache hits
local_cache = InMemoryCache()
messages = [
AIMessage(
content="Hello, how are you?",
usage_metadata={"input_tokens": 5, "output_tokens": 10, "total_tokens": 15},
),
]
model = GenericFakeChatModel(messages=iter(messages), cache=local_cache)
first_response = model.invoke("Hello")
assert isinstance(first_response, AIMessage)
assert first_response.usage_metadata
second_response = model.invoke("Hello")
assert isinstance(second_response, AIMessage)
assert second_response.usage_metadata
assert second_response.usage_metadata["total_cost"] == 0 # type: ignore[typeddict-item]

View File

@@ -0,0 +1,39 @@
"""Internal document utilities."""
from __future__ import annotations
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from langchain_core.documents import Document
def format_document_xml(doc: Document) -> str:
"""Format a document as XML-like structure for LLM consumption.
Args:
doc: Document to format
Returns:
Document wrapped in XML tags:
<document>
<id>...</id>
<content>...</content>
<metadata>...</metadata>
</document>
Note:
Does not generate valid XML or escape special characters.
Intended for semi-structured LLM input only.
"""
id_str = f"<id>{doc.id}</id>" if doc.id is not None else "<id></id>"
metadata_str = ""
if doc.metadata:
metadata_items = [f"{k}: {v!s}" for k, v in doc.metadata.items()]
metadata_str = f"<metadata>{', '.join(metadata_items)}</metadata>"
return (
f"<document>{id_str}"
f"<content>{doc.page_content}</content>"
f"{metadata_str}"
f"</document>"
)
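A minimal sketch of what `format_document_xml` produces for a small document; the import path of the new internal module is assumed here, since file names are suppressed in this diff.

```python
from langchain_core.documents import Document

from langchain._internal._documents import format_document_xml  # assumed path

doc = Document(
    id="doc-1",
    page_content="LangGraph ships a Send API.",
    metadata={"source": "notes.md"},
)
print(format_document_xml(doc))
# <document><id>doc-1</id><content>LangGraph ships a Send API.</content>
# <metadata>source: notes.md</metadata></document>   (printed as one line)
```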

View File

@@ -0,0 +1,36 @@
"""Lazy import utilities."""
from importlib import import_module
from typing import Union
def import_attr(
attr_name: str,
module_name: Union[str, None],
package: Union[str, None],
) -> object:
"""Import an attribute from a module located in a package.
This utility function is used in custom __getattr__ methods within __init__.py
files to dynamically import attributes.
Args:
attr_name: The name of the attribute to import.
module_name: The name of the module to import from. If None, the attribute
is imported from the package itself.
package: The name of the package where the module is located.
"""
if module_name == "__module__" or module_name is None:
try:
result = import_module(f".{attr_name}", package=package)
except ModuleNotFoundError:
msg = f"module '{package!r}' has no attribute {attr_name!r}"
raise AttributeError(msg) from None
else:
try:
module = import_module(f".{module_name}", package=package)
except ModuleNotFoundError as err:
msg = f"module '{package!r}.{module_name!r}' not found ({err})"
raise ImportError(msg) from None
result = getattr(module, attr_name)
return result
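A hedged sketch of how `import_attr` might back a lazy `__getattr__` in a package's `__init__.py`; the import path of the helper and the attribute/submodule names are illustrative, not taken from this diff.

```python
from langchain._internal._lazy_import import import_attr  # assumed path

# Maps an exported attribute name to the submodule it lives in. A value of
# None (or "__module__") means the attribute is imported as a submodule.
_DYNAMIC_IMPORTS = {"init_chat_model": "chat_models"}


def __getattr__(attr_name: str) -> object:
    module_name = _DYNAMIC_IMPORTS.get(attr_name)
    result = import_attr(attr_name, module_name, __package__)
    globals()[attr_name] = result  # cache so later lookups bypass __getattr__
    return result
```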

View File

@@ -0,0 +1,166 @@
"""Internal prompt resolution utilities.
This module provides utilities for resolving different types of prompt specifications
into standardized message formats for language models. It supports both synchronous
and asynchronous prompt resolution with automatic detection of callable types.
The module is designed to handle common prompt patterns across LangChain components,
particularly for summarization chains and other document processing workflows.
"""
from __future__ import annotations
import inspect
from typing import TYPE_CHECKING, Callable, Union
if TYPE_CHECKING:
from collections.abc import Awaitable
from langchain_core.messages import MessageLikeRepresentation
from langgraph.runtime import Runtime
from langchain._internal._typing import ContextT, StateT
def resolve_prompt(
prompt: Union[
str,
None,
Callable[[StateT, Runtime[ContextT]], list[MessageLikeRepresentation]],
],
state: StateT,
runtime: Runtime[ContextT],
default_user_content: str,
default_system_content: str,
) -> list[MessageLikeRepresentation]:
"""Resolve a prompt specification into a list of messages.
Handles prompt resolution across different strategies. Supports callable functions,
string system messages, and None for default behavior.
Args:
prompt: The prompt specification to resolve. Can be:
- Callable: Function taking (state, runtime) returning message list.
- str: A system message string.
- None: Use the provided default system message.
state: Current state, passed to callable prompts.
runtime: LangGraph runtime instance, passed to callable prompts.
default_user_content: User content to include (e.g., document text).
default_system_content: Default system message when prompt is None.
Returns:
List of message dictionaries for language models, typically containing
a system message and user message with content.
Raises:
TypeError: If prompt type is not str, None, or callable.
Example:
```python
def custom_prompt(state, runtime):
return [{"role": "system", "content": "Custom"}]
messages = resolve_prompt(custom_prompt, state, runtime, "content", "default")
messages = resolve_prompt("Custom system", state, runtime, "content", "default")
messages = resolve_prompt(None, state, runtime, "content", "Default")
```
Note:
Callable prompts have full control over message structure; the default
user/system content parameters are ignored. String/None prompts create the
standard system + user structure.
"""
if callable(prompt):
return prompt(state, runtime)
if isinstance(prompt, str):
system_msg = prompt
elif prompt is None:
system_msg = default_system_content
else:
msg = f"Invalid prompt type: {type(prompt)}. Expected str, None, or callable."
raise TypeError(msg)
return [
{"role": "system", "content": system_msg},
{"role": "user", "content": default_user_content},
]
async def aresolve_prompt(
prompt: Union[
str,
None,
Callable[[StateT, Runtime[ContextT]], list[MessageLikeRepresentation]],
Callable[
[StateT, Runtime[ContextT]], Awaitable[list[MessageLikeRepresentation]]
],
],
state: StateT,
runtime: Runtime[ContextT],
default_user_content: str,
default_system_content: str,
) -> list[MessageLikeRepresentation]:
"""Async version of resolve_prompt supporting both sync and async callables.
Handles prompt resolution across different strategies. Supports sync/async callable
functions, string system messages, and None for default behavior.
Args:
prompt: The prompt specification to resolve. Can be:
- Callable (sync): Function taking (state, runtime) returning message list.
- Callable (async): Async function taking (state, runtime) returning
awaitable message list.
- str: A system message string.
- None: Use the provided default system message.
state: Current state, passed to callable prompts.
runtime: LangGraph runtime instance, passed to callable prompts.
default_user_content: User content to include (e.g., document text).
default_system_content: Default system message when prompt is None.
Returns:
List of message dictionaries for language models, typically containing
a system message and user message with content.
Raises:
TypeError: If prompt type is not str, None, or callable.
Example:
```python
async def async_prompt(state, runtime):
return [{"role": "system", "content": "Async"}]
def sync_prompt(state, runtime):
return [{"role": "system", "content": "Sync"}]
messages = await aresolve_prompt(
async_prompt, state, runtime, "content", "default"
)
messages = await aresolve_prompt(
sync_prompt, state, runtime, "content", "default"
)
messages = await aresolve_prompt("Custom", state, runtime, "content", "default")
```
Note:
Callable prompts have full control over message structure; the default
user/system content parameters are ignored. Async callables are detected
and awaited automatically.
"""
if callable(prompt):
result = prompt(state, runtime)
# Check if the result is awaitable (async function)
if inspect.isawaitable(result):
return await result
return result
if isinstance(prompt, str):
system_msg = prompt
elif prompt is None:
system_msg = default_system_content
else:
msg = f"Invalid prompt type: {type(prompt)}. Expected str, None, or callable."
raise TypeError(msg)
return [
{"role": "system", "content": system_msg},
{"role": "user", "content": default_user_content},
]

View File

@@ -0,0 +1,65 @@
"""Private typing utilities for langchain."""
from __future__ import annotations
from typing import TYPE_CHECKING, Any, ClassVar, Protocol, TypeVar, Union
from langgraph.graph._node import StateNode
from pydantic import BaseModel
from typing_extensions import TypeAlias
if TYPE_CHECKING:
from dataclasses import Field
class TypedDictLikeV1(Protocol):
"""Protocol to represent types that behave like TypedDicts.
Version 1: using `ClassVar` for keys.
"""
__required_keys__: ClassVar[frozenset[str]]
__optional_keys__: ClassVar[frozenset[str]]
class TypedDictLikeV2(Protocol):
"""Protocol to represent types that behave like TypedDicts.
Version 2: not using `ClassVar` for keys.
"""
__required_keys__: frozenset[str]
__optional_keys__: frozenset[str]
class DataclassLike(Protocol):
"""Protocol to represent types that behave like dataclasses.
Inspired by the private _DataclassT from dataclasses that uses a similar
protocol as a bound.
"""
__dataclass_fields__: ClassVar[dict[str, Field[Any]]]
StateLike: TypeAlias = Union[TypedDictLikeV1, TypedDictLikeV2, DataclassLike, BaseModel]
"""Type alias for state-like types.
It can either be a `TypedDict`, `dataclass`, or Pydantic `BaseModel`.
Note: we cannot use either `TypedDict` or `dataclass` directly due to limitations in
type checking.
"""
StateT = TypeVar("StateT", bound=StateLike)
"""Type variable used to represent the state in a graph."""
ContextT = TypeVar("ContextT", bound=Union[StateLike, None])
"""Type variable for context types."""
__all__ = [
"ContextT",
"StateLike",
"StateNode",
"StateT",
]
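For reference, a short sketch of the three kinds of state the `StateLike` alias accepts; the class names are illustrative.

```python
from dataclasses import dataclass

from pydantic import BaseModel
from typing_extensions import TypedDict


class SummaryState(TypedDict):  # matched via __required_keys__/__optional_keys__
    text: str


@dataclass
class ExtractionState:  # matched via __dataclass_fields__
    documents: list[str]


class ReportState(BaseModel):  # pydantic models are accepted directly
    title: str
```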

View File

@@ -0,0 +1,7 @@
# Re-exporting internal utilities from LangGraph for internal use in LangChain.
# A different wrapper needs to be created for this purpose in LangChain.
from langgraph._internal._runnable import RunnableCallable
__all__ = [
"RunnableCallable",
]

View File

@@ -0,0 +1,9 @@
from langchain.chains.documents import (
create_map_reduce_chain,
create_stuff_documents_chain,
)
__all__ = [
"create_map_reduce_chain",
"create_stuff_documents_chain",
]

View File

@@ -0,0 +1,17 @@
"""Document extraction chains.
This module provides different strategies for extracting information from collections
of documents using LangGraph and modern language models.
Available Strategies:
- Stuff: Processes all documents together in a single context window
- Map-Reduce: Processes documents in parallel (map), then combines results (reduce)
"""
from langchain.chains.documents.map_reduce import create_map_reduce_chain
from langchain.chains.documents.stuff import create_stuff_documents_chain
__all__ = [
"create_map_reduce_chain",
"create_stuff_documents_chain",
]

View File

@@ -0,0 +1,586 @@
"""Map-Reduce Extraction Implementation using LangGraph Send API."""
from __future__ import annotations
import operator
from typing import (
TYPE_CHECKING,
Annotated,
Any,
Callable,
Generic,
Literal,
Optional,
Union,
cast,
)
from langgraph.graph import END, START, StateGraph
from langgraph.types import Send
from typing_extensions import NotRequired, TypedDict
from langchain._internal._documents import format_document_xml
from langchain._internal._prompts import aresolve_prompt, resolve_prompt
from langchain._internal._typing import ContextT, StateNode
from langchain._internal._utils import RunnableCallable
from langchain.chat_models import init_chat_model
if TYPE_CHECKING:
from langchain_core.documents import Document
from langchain_core.language_models.chat_models import BaseChatModel
# PyCharm is unable to identify that AIMessage is used in the cast below
from langchain_core.messages import (
AIMessage,
MessageLikeRepresentation,
)
from langchain_core.runnables import RunnableConfig
from langgraph.runtime import Runtime
from pydantic import BaseModel
class ExtractionResult(TypedDict):
"""Result from processing a document or group of documents."""
indexes: list[int]
"""Document indexes that contributed to this result."""
result: Any
"""Extracted result from the document(s)."""
class MapReduceState(TypedDict):
"""State for map-reduce extraction chain.
This state tracks the map-reduce process where documents are processed
in parallel during the map phase, then combined in the reduce phase.
"""
documents: list[Document]
"""List of documents to process."""
map_results: Annotated[list[ExtractionResult], operator.add]
"""Individual results from the map phase."""
result: NotRequired[Any]
"""Final combined result from the reduce phase if applicable."""
# The payload for the map phase is a list of documents and their indexes.
# The current implementation only supports a single document per map operation,
# but the structure allows for future expansion to process a group of documents.
# A user would provide an input split function that returns groups of documents
# to process together, if desired.
class MapState(TypedDict):
"""State for individual map operations."""
documents: list[Document]
"""List of documents to process in map phase."""
indexes: list[int]
"""List of indexes of the documents in the original list."""
class InputSchema(TypedDict):
"""Input schema for the map-reduce extraction chain.
Defines the expected input format when invoking the extraction chain.
"""
documents: list[Document]
"""List of documents to process."""
class OutputSchema(TypedDict):
"""Output schema for the map-reduce extraction chain.
Defines the format of the final result returned by the chain.
"""
map_results: list[ExtractionResult]
"""List of individual extraction results from the map phase."""
result: Any
"""Final combined result from all documents."""
class MapReduceNodeUpdate(TypedDict):
"""Update returned by map-reduce nodes."""
map_results: NotRequired[list[ExtractionResult]]
"""Updated results after map phase."""
result: NotRequired[Any]
"""Final result after reduce phase."""
class _MapReduceExtractor(Generic[ContextT]):
"""Map-reduce extraction implementation using LangGraph Send API.
This implementation uses a language model to process documents through up
to two phases:
1. **Map Phase**: Each document is processed independently by the LLM using
the configured map_prompt to generate individual extraction results.
2. **Reduce Phase (Optional)**: Individual results can optionally be
combined using either:
- The default LLM-based reducer with the configured reduce_prompt
- A custom reducer function (which can be non-LLM based)
- Skipped entirely by setting reduce=None
The map phase processes documents in parallel for efficiency, making this approach
well-suited for large document collections. The reduce phase is flexible and can be
customized or omitted based on your specific requirements.
"""
def __init__(
self,
model: Union[BaseChatModel, str],
*,
map_prompt: Union[
str,
None,
Callable[
[MapState, Runtime[ContextT]],
list[MessageLikeRepresentation],
],
] = None,
reduce_prompt: Union[
str,
None,
Callable[
[MapReduceState, Runtime[ContextT]],
list[MessageLikeRepresentation],
],
] = None,
reduce: Union[
Literal["default_reducer"],
None,
StateNode,
] = "default_reducer",
context_schema: type[ContextT] | None = None,
response_format: Optional[type[BaseModel]] = None,
) -> None:
"""Initialize the MapReduceExtractor.
Args:
model: The language model, either a chat model instance
(e.g., `ChatAnthropic()`) or a string identifier
(e.g., `"anthropic:claude-sonnet-4-20250514"`).
map_prompt: Prompt for individual document processing. Can be:
- str: A system message string
- None: Use default system message
- Callable: A function that takes (state, runtime) and returns messages
reduce_prompt: Prompt for combining results. Can be:
- str: A system message string
- None: Use default system message
- Callable: A function that takes (state, runtime) and returns messages
reduce: Controls the reduce behavior. Can be:
- "default_reducer": Use the default LLM-based reduce step
- None: Skip the reduce step entirely
- Callable: Custom reduce function (sync or async)
context_schema: Optional context schema for the LangGraph runtime.
response_format: Optional pydantic BaseModel for structured output.
"""
if (reduce is None or callable(reduce)) and reduce_prompt is not None:
msg = (
"reduce_prompt must be None when reduce is None or a custom "
"callable. Custom reduce functions handle their own logic and "
"should not use reduce_prompt."
)
raise ValueError(msg)
self.response_format = response_format
if isinstance(model, str):
model = init_chat_model(model)
self.model = (
model.with_structured_output(response_format) if response_format else model
)
self.map_prompt = map_prompt
self.reduce_prompt = reduce_prompt
self.reduce = reduce
self.context_schema = context_schema
def _get_map_prompt(
self, state: MapState, runtime: Runtime[ContextT]
) -> list[MessageLikeRepresentation]:
"""Generate the LLM prompt for processing documents."""
documents = state["documents"]
user_content = "\n\n".join(format_document_xml(doc) for doc in documents)
default_system = (
"You are a helpful assistant that processes documents. "
"Please process the following documents and provide a result."
)
return resolve_prompt(
self.map_prompt,
state,
runtime,
user_content,
default_system,
)
async def _aget_map_prompt(
self, state: MapState, runtime: Runtime[ContextT]
) -> list[MessageLikeRepresentation]:
"""Generate the LLM prompt for processing documents in the map phase.
Async version.
"""
documents = state["documents"]
user_content = "\n\n".join(format_document_xml(doc) for doc in documents)
default_system = (
"You are a helpful assistant that processes documents. "
"Please process the following documents and provide a result."
)
return await aresolve_prompt(
self.map_prompt,
state,
runtime,
user_content,
default_system,
)
def _get_reduce_prompt(
self, state: MapReduceState, runtime: Runtime[ContextT]
) -> list[MessageLikeRepresentation]:
"""Generate the LLM prompt for combining individual results.
Combines map results in the reduce phase.
"""
map_results = state.get("map_results", [])
if not map_results:
msg = (
"Internal programming error: Results must exist when reducing. "
"This indicates that the reduce node was reached without "
"first processing the map nodes, which violates "
"the expected graph execution order."
)
raise AssertionError(msg)
results_text = "\n\n".join(
f"Result {i + 1} (from documents "
f"{', '.join(map(str, result['indexes']))}):\n{result['result']}"
for i, result in enumerate(map_results)
)
user_content = (
f"Please combine the following results into a single, "
f"comprehensive result:\n\n{results_text}"
)
default_system = (
"You are a helpful assistant that combines multiple results. "
"Given several individual results, create a single comprehensive "
"result that captures the key information from all inputs while "
"maintaining conciseness and coherence."
)
return resolve_prompt(
self.reduce_prompt,
state,
runtime,
user_content,
default_system,
)
async def _aget_reduce_prompt(
self, state: MapReduceState, runtime: Runtime[ContextT]
) -> list[MessageLikeRepresentation]:
"""Generate the LLM prompt for combining individual results.
Async version of reduce phase.
"""
map_results = state.get("map_results", [])
if not map_results:
msg = (
"Internal programming error: Results must exist when reducing. "
"This indicates that the reduce node was reached without "
"first processing the map nodes, which violates "
"the expected graph execution order."
)
raise AssertionError(msg)
results_text = "\n\n".join(
f"Result {i + 1} (from documents "
f"{', '.join(map(str, result['indexes']))}):\n{result['result']}"
for i, result in enumerate(map_results)
)
user_content = (
f"Please combine the following results into a single, "
f"comprehensive result:\n\n{results_text}"
)
default_system = (
"You are a helpful assistant that combines multiple results. "
"Given several individual results, create a single comprehensive "
"result that captures the key information from all inputs while "
"maintaining conciseness and coherence."
)
return await aresolve_prompt(
self.reduce_prompt,
state,
runtime,
user_content,
default_system,
)
def create_map_node(self) -> RunnableCallable:
"""Create a LangGraph node that processes individual documents using the LLM."""
def _map_node(
state: MapState, runtime: Runtime[ContextT], config: RunnableConfig
) -> dict[str, list[ExtractionResult]]:
prompt = self._get_map_prompt(state, runtime)
response = cast("AIMessage", self.model.invoke(prompt, config=config))
result = response if self.response_format else response.text()
extraction_result: ExtractionResult = {
"indexes": state["indexes"],
"result": result,
}
return {"map_results": [extraction_result]}
async def _amap_node(
state: MapState,
runtime: Runtime[ContextT],
config: RunnableConfig,
) -> dict[str, list[ExtractionResult]]:
prompt = await self._aget_map_prompt(state, runtime)
response = cast(
"AIMessage", await self.model.ainvoke(prompt, config=config)
)
result = response if self.response_format else response.text()
extraction_result: ExtractionResult = {
"indexes": state["indexes"],
"result": result,
}
return {"map_results": [extraction_result]}
return RunnableCallable(
_map_node,
_amap_node,
trace=False,
)
def create_reduce_node(self) -> RunnableCallable:
"""Create a LangGraph node that combines individual results using the LLM."""
def _reduce_node(
state: MapReduceState, runtime: Runtime[ContextT], config: RunnableConfig
) -> MapReduceNodeUpdate:
prompt = self._get_reduce_prompt(state, runtime)
response = cast("AIMessage", self.model.invoke(prompt, config=config))
result = response if self.response_format else response.text()
return {"result": result}
async def _areduce_node(
state: MapReduceState,
runtime: Runtime[ContextT],
config: RunnableConfig,
) -> MapReduceNodeUpdate:
prompt = await self._aget_reduce_prompt(state, runtime)
response = cast(
"AIMessage", await self.model.ainvoke(prompt, config=config)
)
result = response if self.response_format else response.text()
return {"result": result}
return RunnableCallable(
_reduce_node,
_areduce_node,
trace=False,
)
def continue_to_map(self, state: MapReduceState) -> list[Send]:
"""Generate Send objects for parallel map operations."""
return [
Send("map_process", {"documents": [doc], "indexes": [i]})
for i, doc in enumerate(state["documents"])
]
def build(
self,
) -> StateGraph[MapReduceState, ContextT, InputSchema, OutputSchema]:
"""Build and compile the LangGraph for map-reduce summarization."""
builder = StateGraph(
MapReduceState,
context_schema=self.context_schema,
input_schema=InputSchema,
output_schema=OutputSchema,
)
builder.add_node("map_process", self.create_map_node())
builder.add_edge(START, "continue_to_map")
# add_conditional_edges doesn't explicitly type Send
builder.add_conditional_edges(
"continue_to_map",
self.continue_to_map, # type: ignore[arg-type]
["map_process"],
)
if self.reduce is None:
builder.add_edge("map_process", END)
elif self.reduce == "default_reducer":
builder.add_node("reduce_process", self.create_reduce_node())
builder.add_edge("map_process", "reduce_process")
builder.add_edge("reduce_process", END)
else:
reduce_node = cast("StateNode", self.reduce)
# The type is ignored here. Requires parameterizing with generics.
builder.add_node("reduce_process", reduce_node) # type: ignore[arg-type]
builder.add_edge("map_process", "reduce_process")
builder.add_edge("reduce_process", END)
return builder
def create_map_reduce_chain(
model: Union[BaseChatModel, str],
*,
map_prompt: Union[
str,
None,
Callable[[MapState, Runtime[ContextT]], list[MessageLikeRepresentation]],
] = None,
reduce_prompt: Union[
str,
None,
Callable[[MapReduceState, Runtime[ContextT]], list[MessageLikeRepresentation]],
] = None,
reduce: Union[
Literal["default_reducer"],
None,
StateNode,
] = "default_reducer",
context_schema: type[ContextT] | None = None,
response_format: Optional[type[BaseModel]] = None,
) -> StateGraph[MapReduceState, ContextT, InputSchema, OutputSchema]:
"""Create a map-reduce document extraction chain.
This implementation uses a language model to extract information from documents
through a flexible approach that efficiently handles large document collections
by processing documents in parallel.
**Processing Flow:**
1. **Map Phase**: Each document is independently processed by the LLM
using the map_prompt to extract relevant information and generate
individual results.
2. **Reduce Phase (Optional)**: Individual extraction results can
optionally be combined using:
- The default LLM-based reducer with reduce_prompt (default behavior)
- A custom reducer function (can be non-LLM based)
- Skipped entirely by setting reduce=None
3. **Output**: Returns the individual map results and optionally the final
combined result.
Example:
>>> from langchain_anthropic import ChatAnthropic
>>> from langchain_core.documents import Document
>>>
>>> model = ChatAnthropic(
... model="claude-sonnet-4-20250514",
... temperature=0,
... max_tokens=62_000,
... timeout=None,
... max_retries=2,
... )
>>> builder = create_map_reduce_chain(model)
>>> chain = builder.compile()
>>> docs = [
... Document(page_content="First document content..."),
... Document(page_content="Second document content..."),
... Document(page_content="Third document content..."),
... ]
>>> result = chain.invoke({"documents": docs})
>>> print(result["result"])
Example with string model:
>>> builder = create_map_reduce_chain("anthropic:claude-sonnet-4-20250514")
>>> chain = builder.compile()
>>> result = chain.invoke({"documents": docs})
>>> print(result["result"])
Example with structured output:
```python
from pydantic import BaseModel
class ExtractionModel(BaseModel):
title: str
key_points: list[str]
conclusion: str
builder = create_map_reduce_chain(
model,
response_format=ExtractionModel
)
chain = builder.compile()
result = chain.invoke({"documents": docs})
print(result["result"].title) # Access structured fields
```
Example skipping the reduce phase:
```python
# Only perform map phase, skip combining results
builder = create_map_reduce_chain(model, reduce=None)
chain = builder.compile()
result = chain.invoke({"documents": docs})
# result["result"] will be None, only map_results are available
for map_result in result["map_results"]:
print(f"Document {map_result['indexes'][0]}: {map_result['result']}")
```
Example with custom reducer:
```python
def custom_aggregator(state, runtime):
# Custom non-LLM based reduction logic
map_results = state["map_results"]
combined_text = " | ".join(r["result"] for r in map_results)
word_count = len(combined_text.split())
return {
"result": f"Combined {len(map_results)} results with "
f"{word_count} total words"
}
builder = create_map_reduce_chain(model, reduce=custom_aggregator)
chain = builder.compile()
result = chain.invoke({"documents": docs})
print(result["result"]) # Custom aggregated result
```
Args:
model: The language model, either a chat model instance
(e.g., `ChatAnthropic()`) or a string identifier
(e.g., `"anthropic:claude-sonnet-4-20250514"`).
map_prompt: Prompt for individual document processing. Can be:
- str: A system message string
- None: Use default system message
- Callable: A function that takes (state, runtime) and returns messages
reduce_prompt: Prompt for combining results. Can be:
- str: A system message string
- None: Use default system message
- Callable: A function that takes (state, runtime) and returns messages
reduce: Controls the reduce behavior. Can be:
- "default_reducer": Use the default LLM-based reduce step
- None: Skip the reduce step entirely
- Callable: Custom reduce function (sync or async)
context_schema: Optional context schema for the LangGraph runtime.
response_format: Optional pydantic BaseModel for structured output.
Returns:
A LangGraph that can be invoked with documents to get map-reduce
extraction results.
Note:
This implementation is well-suited for large document collections as it
processes documents in parallel during the map phase. The Send API enables
efficient parallelization while maintaining clean state management.
"""
extractor = _MapReduceExtractor(
model,
map_prompt=map_prompt,
reduce_prompt=reduce_prompt,
reduce=reduce,
context_schema=context_schema,
response_format=response_format,
)
return extractor.build()
__all__ = ["create_map_reduce_chain"]
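The docstring above describes callable prompts but only shows them in brief; a hedged end-to-end sketch with a callable `map_prompt` (model id, prompt text, and document contents are illustrative):

```python
from langchain_core.documents import Document

from langchain.chains.documents import create_map_reduce_chain


def map_prompt(state, runtime):
    # state is the per-document MapState defined above.
    content = "\n\n".join(doc.page_content for doc in state["documents"])
    return [
        {"role": "system", "content": "List the key entities in the document."},
        {"role": "user", "content": content},
    ]


builder = create_map_reduce_chain(
    "anthropic:claude-sonnet-4-20250514", map_prompt=map_prompt
)
chain = builder.compile()
result = chain.invoke({"documents": [Document(page_content="Acme hired Jo in 2024.")]})
print(result["result"])
```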

View File

@@ -0,0 +1,473 @@
"""Stuff documents chain for processing documents by putting them all in context."""
from __future__ import annotations
from typing import (
TYPE_CHECKING,
Any,
Callable,
Generic,
Optional,
Union,
cast,
)
# Used not only for type checking, but is fetched at runtime by Pydantic.
from langchain_core.documents import Document as Document # noqa: TC002
from langgraph.graph import START, StateGraph
from typing_extensions import NotRequired, TypedDict
from langchain._internal._documents import format_document_xml
from langchain._internal._prompts import aresolve_prompt, resolve_prompt
from langchain._internal._typing import ContextT
from langchain._internal._utils import RunnableCallable
from langchain.chat_models import init_chat_model
if TYPE_CHECKING:
from langchain_core.language_models.chat_models import BaseChatModel
# Used for type checking, but IDEs may not recognize it inside the cast.
from langchain_core.messages import AIMessage as AIMessage
from langchain_core.messages import MessageLikeRepresentation
from langchain_core.runnables import RunnableConfig
from langgraph.runtime import Runtime
from pydantic import BaseModel
# Default system prompts
DEFAULT_INIT_PROMPT = (
"You are a helpful assistant that summarizes text. "
"Please provide a concise summary of the documents "
"provided by the user."
)
DEFAULT_STRUCTURED_INIT_PROMPT = (
"You are a helpful assistant that extracts structured information from documents. "
"Use the provided content and optional question to generate your output, formatted "
"according to the predefined schema."
)
DEFAULT_REFINE_PROMPT = (
"You are a helpful assistant that refines summaries. "
"Given an existing summary and new context, produce a refined summary "
"that incorporates the new information while maintaining conciseness."
)
DEFAULT_STRUCTURED_REFINE_PROMPT = (
"You are a helpful assistant refining structured information extracted "
"from documents. "
"You are given a previous result and new document context. "
"Update the output to reflect the new context, staying consistent with "
"the expected schema."
)
def _format_documents_content(documents: list[Document]) -> str:
"""Format documents into content string.
Args:
documents: List of documents to format.
Returns:
Formatted document content string.
"""
return "\n\n".join(format_document_xml(doc) for doc in documents)
class ExtractionState(TypedDict):
"""State for extraction chain.
This state tracks the extraction process where documents
are processed in batch, with the result being refined if needed.
"""
documents: list[Document]
"""List of documents to process."""
result: NotRequired[Any]
"""Current result, refined with each document."""
class InputSchema(TypedDict):
"""Input schema for the extraction chain.
Defines the expected input format when invoking the extraction chain.
"""
documents: list[Document]
"""List of documents to process."""
result: NotRequired[Any]
"""Existing result to refine (optional)."""
class OutputSchema(TypedDict):
"""Output schema for the extraction chain.
Defines the format of the final result returned by the chain.
"""
result: Any
"""Result from processing the documents."""
class ExtractionNodeUpdate(TypedDict):
"""Update returned by processing nodes."""
result: NotRequired[Any]
"""Updated result after processing a document."""
class _Extractor(Generic[ContextT]):
"""Stuff documents chain implementation.
This chain works by putting all the documents in the batch into the context
window of the language model. It processes all documents together in a single
request for extracting information or summaries. Can refine existing results
when provided.
Important: This chain does not attempt to control for the size of the context
window of the LLM. Ensure your documents fit within the model's context limits.
"""
def __init__(
self,
model: Union[BaseChatModel, str],
*,
prompt: Union[
str,
None,
Callable[
[ExtractionState, Runtime[ContextT]],
list[MessageLikeRepresentation],
],
] = None,
refine_prompt: Union[
str,
None,
Callable[
[ExtractionState, Runtime[ContextT]],
list[MessageLikeRepresentation],
],
] = None,
context_schema: type[ContextT] | None = None,
response_format: Optional[type[BaseModel]] = None,
) -> None:
"""Initialize the Extractor.
Args:
model: The language model, either a chat model instance
(e.g., `ChatAnthropic()`) or a string identifier
(e.g., `"anthropic:claude-sonnet-4-20250514"`).
prompt: Prompt for initial processing. Can be:
- str: A system message string
- None: Use default system message
- Callable: A function that takes (state, runtime) and returns messages
refine_prompt: Prompt for refinement steps. Can be:
- str: A system message string
- None: Use default system message
- Callable: A function that takes (state, runtime) and returns messages
context_schema: Optional context schema for the LangGraph runtime.
response_format: Optional pydantic BaseModel for structured output.
"""
self.response_format = response_format
if isinstance(model, str):
model = init_chat_model(model)
self.model = (
model.with_structured_output(response_format) if response_format else model
)
self.initial_prompt = prompt
self.refine_prompt = refine_prompt
self.context_schema = context_schema
def _get_initial_prompt(
self, state: ExtractionState, runtime: Runtime[ContextT]
) -> list[MessageLikeRepresentation]:
"""Generate the initial extraction prompt."""
user_content = _format_documents_content(state["documents"])
# Choose default prompt based on structured output format
default_prompt = (
DEFAULT_STRUCTURED_INIT_PROMPT
if self.response_format
else DEFAULT_INIT_PROMPT
)
return resolve_prompt(
self.initial_prompt,
state,
runtime,
user_content,
default_prompt,
)
async def _aget_initial_prompt(
self, state: ExtractionState, runtime: Runtime[ContextT]
) -> list[MessageLikeRepresentation]:
"""Generate the initial extraction prompt (async version)."""
user_content = _format_documents_content(state["documents"])
# Choose default prompt based on structured output format
default_prompt = (
DEFAULT_STRUCTURED_INIT_PROMPT
if self.response_format
else DEFAULT_INIT_PROMPT
)
return await aresolve_prompt(
self.initial_prompt,
state,
runtime,
user_content,
default_prompt,
)
def _get_refine_prompt(
self, state: ExtractionState, runtime: Runtime[ContextT]
) -> list[MessageLikeRepresentation]:
"""Generate the refinement prompt."""
# Result should be guaranteed to exist at refinement stage
if "result" not in state or state["result"] == "":
msg = (
"Internal programming error: Result must exist when refining. "
"This indicates that the refinement node was reached without "
"first processing the initial result node, which violates "
"the expected graph execution order."
)
raise AssertionError(msg)
new_context = _format_documents_content(state["documents"])
user_content = (
f"Previous result:\n{state['result']}\n\n"
f"New context:\n{new_context}\n\n"
f"Please provide a refined result."
)
# Choose default prompt based on structured output format
default_prompt = (
DEFAULT_STRUCTURED_REFINE_PROMPT
if self.response_format
else DEFAULT_REFINE_PROMPT
)
return resolve_prompt(
self.refine_prompt,
state,
runtime,
user_content,
default_prompt,
)
async def _aget_refine_prompt(
self, state: ExtractionState, runtime: Runtime[ContextT]
) -> list[MessageLikeRepresentation]:
"""Generate the refinement prompt (async version)."""
# Result should be guaranteed to exist at refinement stage
if "result" not in state or state["result"] == "":
msg = (
"Internal programming error: Result must exist when refining. "
"This indicates that the refinement node was reached without "
"first processing the initial result node, which violates "
"the expected graph execution order."
)
raise AssertionError(msg)
new_context = _format_documents_content(state["documents"])
user_content = (
f"Previous result:\n{state['result']}\n\n"
f"New context:\n{new_context}\n\n"
f"Please provide a refined result."
)
# Choose default prompt based on structured output format
default_prompt = (
DEFAULT_STRUCTURED_REFINE_PROMPT
if self.response_format
else DEFAULT_REFINE_PROMPT
)
return await aresolve_prompt(
self.refine_prompt,
state,
runtime,
user_content,
default_prompt,
)
def create_document_processor_node(self) -> RunnableCallable:
"""Create the main document processing node.
The node handles both initial processing and refinement of results.
Refinement is done by providing the existing result and new context.
If the workflow is run with a checkpointer enabled, the result will be
persisted and available for a given thread id.
"""
def _process_node(
state: ExtractionState, runtime: Runtime[ContextT], config: RunnableConfig
) -> ExtractionNodeUpdate:
# Handle empty document list
if not state["documents"]:
return {}
# Determine if this is initial processing or refinement
if "result" not in state or state["result"] == "":
# Initial processing
prompt = self._get_initial_prompt(state, runtime)
response = cast("AIMessage", self.model.invoke(prompt, config=config))
result = response if self.response_format else response.text()
return {"result": result}
# Refinement
prompt = self._get_refine_prompt(state, runtime)
response = cast("AIMessage", self.model.invoke(prompt, config=config))
result = response if self.response_format else response.text()
return {"result": result}
async def _aprocess_node(
state: ExtractionState,
runtime: Runtime[ContextT],
config: RunnableConfig,
) -> ExtractionNodeUpdate:
# Handle empty document list
if not state["documents"]:
return {}
# Determine if this is initial processing or refinement
if "result" not in state or state["result"] == "":
# Initial processing
prompt = await self._aget_initial_prompt(state, runtime)
response = cast(
"AIMessage", await self.model.ainvoke(prompt, config=config)
)
result = response if self.response_format else response.text()
return {"result": result}
# Refinement
prompt = await self._aget_refine_prompt(state, runtime)
response = cast(
"AIMessage", await self.model.ainvoke(prompt, config=config)
)
result = response if self.response_format else response.text()
return {"result": result}
return RunnableCallable(
_process_node,
_aprocess_node,
trace=False,
)
def build(
self,
) -> StateGraph[ExtractionState, ContextT, InputSchema, OutputSchema]:
"""Build and compile the LangGraph for batch document extraction."""
builder = StateGraph(
ExtractionState,
context_schema=self.context_schema,
input_schema=InputSchema,
output_schema=OutputSchema,
)
builder.add_edge(START, "process")
builder.add_node("process", self.create_document_processor_node())
return builder
def create_stuff_documents_chain(
model: Union[BaseChatModel, str],
*,
prompt: Union[
str,
None,
Callable[[ExtractionState, Runtime[ContextT]], list[MessageLikeRepresentation]],
] = None,
refine_prompt: Union[
str,
None,
Callable[[ExtractionState, Runtime[ContextT]], list[MessageLikeRepresentation]],
] = None,
context_schema: type[ContextT] | None = None,
response_format: Optional[type[BaseModel]] = None,
) -> StateGraph[ExtractionState, ContextT, InputSchema, OutputSchema]:
"""Create a stuff documents chain for processing documents.
This chain works by putting all the documents in the batch into the context
window of the language model. It processes all documents together in a single
request for extracting information or summaries. Can refine existing results
when provided. The default prompts are optimized for summarization tasks, but
can be customized for other extraction tasks via the prompt parameters or
response_format.
Strategy:
1. Put all documents into the context window
2. Process all documents together in a single request
3. If an existing result is provided, refine it with all documents at once
4. Return the result
Important:
This chain does not attempt to control for the size of the context
window of the LLM. Ensure your documents fit within the model's context limits.
Example:
```python
from langchain.chat_models import init_chat_model
from langchain_core.documents import Document
model = init_chat_model("anthropic:claude-sonnet-4-20250514")
builder = create_stuff_documents_chain(model)
chain = builder.compile()
docs = [
Document(page_content="First document content..."),
Document(page_content="Second document content..."),
Document(page_content="Third document content..."),
]
result = chain.invoke({"documents": docs})
print(result["result"])
# Structured summary/extraction by passing a schema
from pydantic import BaseModel
class Summary(BaseModel):
title: str
key_points: list[str]
builder = create_stuff_documents_chain(model, response_format=Summary)
chain = builder.compile()
result = chain.invoke({"documents": docs})
print(result["result"].title) # Access structured fields
```
Args:
model: The language model for document processing.
prompt: Prompt for initial processing. Can be:
- str: A system message string
- None: Use default system message
- Callable: A function that takes (state, runtime) and returns messages
refine_prompt: Prompt for refinement steps. Can be:
- str: A system message string
- None: Use default system message
- Callable: A function that takes (state, runtime) and returns messages
context_schema: Optional context schema for the LangGraph runtime.
response_format: Optional pydantic BaseModel for structured output.
Returns:
A LangGraph that can be invoked with documents to extract information.
Note:
This is a "stuff" documents chain that puts all documents into the context
window and processes them together. It supports refining existing results.
Default prompts are optimized for summarization but can be customized for
other tasks. Important: Does not control for context window size.
"""
extractor = _Extractor(
model,
prompt=prompt,
refine_prompt=refine_prompt,
context_schema=context_schema,
response_format=response_format,
)
return extractor.build()
__all__ = ["create_stuff_documents_chain"]
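One path the docstring above mentions but does not demonstrate is refining an existing result via the optional `result` input key; a hedged sketch (model id and document contents are illustrative):

```python
from langchain_core.documents import Document

from langchain.chains.documents import create_stuff_documents_chain

chain = create_stuff_documents_chain("anthropic:claude-sonnet-4-20250514").compile()
docs = [
    Document(page_content="Q1 revenue grew 10%."),
    Document(page_content="Q2 revenue grew 12%."),
]

first = chain.invoke({"documents": docs[:1]})
# Pass the earlier result back in; the chain refines it with the new context.
refined = chain.invoke({"documents": docs[1:], "result": first["result"]})
print(refined["result"])
```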

View File

@@ -0,0 +1,5 @@
from langchain_core.documents import Document
__all__ = [
"Document",
]

View File

@@ -9,7 +9,7 @@ requires-python = ">=3.9, <4.0"
dependencies = [
"langchain-core<1.0.0,>=0.3.66",
"langchain-text-splitters<1.0.0,>=0.3.8",
"langgraph>=0.5.4",
"langgraph>=0.6.0",
"pydantic>=2.7.4",
]
@@ -123,6 +123,7 @@ ignore = [
"UP007", # pyupgrade: non-pep604-annotation-union
"PLC0415", # Imports should be at the top. Not always desirable
"PLR0913", # Too many arguments in function definition
"PLC0414", # Inconsistent with how type checkers expect to be notified of intentional re-exports
]
unfixable = ["B028"] # People should intentionally tune the stacklevel

View File

@@ -1667,7 +1667,7 @@ requires-dist = [
{ name = "langchain-text-splitters", editable = "../text-splitters" },
{ name = "langchain-together", marker = "extra == 'together'" },
{ name = "langchain-xai", marker = "extra == 'xai'" },
{ name = "langgraph", specifier = ">=0.5.4" },
{ name = "langgraph", specifier = ">=0.6.0" },
{ name = "pydantic", specifier = ">=2.7.4" },
]
provides-extras = ["anthropic", "openai", "azure-ai", "google-vertexai", "google-genai", "fireworks", "ollama", "together", "mistralai", "huggingface", "groq", "aws", "deepseek", "xai", "perplexity"]
@@ -1757,7 +1757,7 @@ wheels = [
[[package]]
name = "langchain-core"
version = "0.3.70"
version = "0.3.72"
source = { editable = "../core" }
dependencies = [
{ name = "jsonpatch" },
@@ -2133,7 +2133,7 @@ wheels = [
[[package]]
name = "langgraph"
version = "0.5.4"
version = "0.6.1"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "langchain-core" },
@@ -2143,9 +2143,9 @@ dependencies = [
{ name = "pydantic" },
{ name = "xxhash" },
]
sdist = { url = "https://files.pythonhosted.org/packages/99/26/f01ae40ea26f8c723b6ec186869c80cc04de801630d99943018428b46105/langgraph-0.5.4.tar.gz", hash = "sha256:ab8f6b7b9c50fd2ae35a2efb072fbbfe79500dfc18071ac4ba6f5de5fa181931", size = 443149, upload-time = "2025-07-21T18:20:55.63Z" }
sdist = { url = "https://files.pythonhosted.org/packages/6e/12/4a30f766de571bfc319e70c2c0f4d050c0576e15c2249dc75ad122706a5d/langgraph-0.6.1.tar.gz", hash = "sha256:e4399ac5ad0b70f58fa28d6fe05a41b84c15959f270d6d1a86edab4e92ae148b", size = 449723, upload-time = "2025-07-29T20:45:28.438Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/0d/82/15184e953234877107bad182b79c9111cb6ce6a79a97fdf36ebcaa11c0d0/langgraph-0.5.4-py3-none-any.whl", hash = "sha256:7122840225623e081be24ac30a691a24e5dac4c0361f593208f912838192d7f6", size = 143942, upload-time = "2025-07-21T18:20:54.442Z" },
{ url = "https://files.pythonhosted.org/packages/6e/90/e37e2bc19bee81f859bfd61526d977f54ec5d8e036fa091f2cfe4f19560b/langgraph-0.6.1-py3-none-any.whl", hash = "sha256:2736027faeb6cd5c0f1ab51a5345594cfcb5eb5beeb5ac1799a58fcecf4b4eae", size = 151874, upload-time = "2025-07-29T20:45:26.998Z" },
]
[[package]]
@@ -2163,28 +2163,28 @@ wheels = [
[[package]]
name = "langgraph-prebuilt"
version = "0.5.2"
version = "0.6.1"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "langchain-core" },
{ name = "langgraph-checkpoint" },
]
sdist = { url = "https://files.pythonhosted.org/packages/bb/11/98134c47832fbde0caf0e06f1a104577da9215c358d7854093c1d835b272/langgraph_prebuilt-0.5.2.tar.gz", hash = "sha256:2c900a5be0d6a93ea2521e0d931697cad2b646f1fcda7aa5c39d8d7539772465", size = 117808, upload-time = "2025-06-30T19:52:48.307Z" }
sdist = { url = "https://files.pythonhosted.org/packages/55/b6/d4f8e800bdfdd75486595203d5c622bba5f1098e4fd4220452c75568f2be/langgraph_prebuilt-0.6.1.tar.gz", hash = "sha256:574c409113e02d3c58157877c5ea638faa80647b259027647ab88830d7ecef00", size = 125057, upload-time = "2025-07-29T20:44:48.634Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/c3/64/6bc45ab9e0e1112698ebff579fe21f5606ea65cd08266995a357e312a4d2/langgraph_prebuilt-0.5.2-py3-none-any.whl", hash = "sha256:1f4cd55deca49dffc3e5127eec12fcd244fc381321002f728afa88642d5ec59d", size = 23776, upload-time = "2025-06-30T19:52:47.494Z" },
{ url = "https://files.pythonhosted.org/packages/a6/df/cb4d73e99719b7ca0d42503d39dec49c67225779c6a1e689954a5604dbe6/langgraph_prebuilt-0.6.1-py3-none-any.whl", hash = "sha256:a3a970451371ec66509c6969505286a5d92132af7062d0b2b6dab08c2e27b50f", size = 28866, upload-time = "2025-07-29T20:44:47.72Z" },
]
[[package]]
name = "langgraph-sdk"
version = "0.1.74"
version = "0.2.0"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "httpx" },
{ name = "orjson" },
]
sdist = { url = "https://files.pythonhosted.org/packages/6d/f7/3807b72988f7eef5e0eb41e7e695eca50f3ed31f7cab5602db3b651c85ff/langgraph_sdk-0.1.74.tar.gz", hash = "sha256:7450e0db5b226cc2e5328ca22c5968725873630ef47c4206a30707cb25dc3ad6", size = 72190, upload-time = "2025-07-21T16:36:50.032Z" }
sdist = { url = "https://files.pythonhosted.org/packages/2a/3e/3dc45dc7682c9940db9edaf8773d2e157397c5bd6881f6806808afd8731e/langgraph_sdk-0.2.0.tar.gz", hash = "sha256:cd8b5f6595e5571be5cbffd04cf936978ab8f5d1005517c99715947ef871e246", size = 72510, upload-time = "2025-07-22T17:31:06.745Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/1f/1a/3eacc4df8127781ee4b0b1e5cad7dbaf12510f58c42cbcb9d1e2dba2a164/langgraph_sdk-0.1.74-py3-none-any.whl", hash = "sha256:3a265c3757fe0048adad4391d10486db63ef7aa5a2cbd22da22d4503554cb890", size = 50254, upload-time = "2025-07-21T16:36:49.134Z" },
{ url = "https://files.pythonhosted.org/packages/a5/03/a8ab0e8ea74be6058cb48bb1d85485b5c65d6ea183e3ee1aa8ca1ac73b3e/langgraph_sdk-0.2.0-py3-none-any.whl", hash = "sha256:150722264f225c4d47bbe7394676be102fdbf04c4400a0dd1bd41a70c6430cc7", size = 50569, upload-time = "2025-07-22T17:31:04.582Z" },
]
[[package]]

View File

@@ -2731,6 +2731,31 @@ class ChatOpenAI(BaseChatOpenAI): # type: ignore[override]
Always use ``extra_body`` for custom parameters, **not** ``model_kwargs``.
Using ``model_kwargs`` for non-OpenAI parameters will cause API errors.
.. dropdown:: Prompt caching optimization
For high-volume applications with repetitive prompts, use ``prompt_cache_key``
per-invocation to improve cache hit rates and reduce costs:
.. code-block:: python
llm = ChatOpenAI(model="gpt-4o-mini")
response = llm.invoke(
messages,
prompt_cache_key="example-key-a", # Routes to same machine for cache hits
)
customer_response = llm.invoke(messages, prompt_cache_key="example-key-b")
support_response = llm.invoke(messages, prompt_cache_key="example-key-c")
# Dynamic cache keys based on context
cache_key = f"example-key-{dynamic_suffix}"
response = llm.invoke(messages, prompt_cache_key=cache_key)
Cache keys help ensure requests with the same prompt prefix are routed to
machines with an existing cache, reducing cost and latency for cached tokens.
""" # noqa: E501
max_tokens: Optional[int] = Field(default=None, alias="max_completion_tokens")
@@ -3716,6 +3741,20 @@ def _construct_responses_api_input(messages: Sequence[BaseMessage]) -> list:
return input_
def _get_output_text(response: Response) -> str:
"""OpenAI SDK deleted response.output_text in 1.99.2"""
if hasattr(response, "output_text"):
return response.output_text
texts: list[str] = []
for output in response.output:
if output.type == "message":
for content in output.content:
if content.type == "output_text":
texts.append(content.text)
return "".join(texts)
def _construct_lc_result_from_responses_api(
response: Response,
schema: Optional[type[_BM]] = None,
@@ -3830,17 +3869,18 @@ def _construct_lc_result_from_responses_api(
# text_format=Foo,
# stream=True, # <-- errors
# )
output_text = _get_output_text(response)
if (
schema is not None
and "parsed" not in additional_kwargs
and response.output_text # tool calls can generate empty output text
and output_text # tool calls can generate empty output text
and response.text
and (text_config := response.text.model_dump())
and (format_ := text_config.get("format", {}))
and (format_.get("type") == "json_schema")
):
try:
parsed_dict = json.loads(response.output_text)
parsed_dict = json.loads(output_text)
if schema and _is_pydantic_class(schema):
parsed = schema(**parsed_dict)
else:

View File

@@ -1112,3 +1112,46 @@ def test_tools_and_structured_output() -> None:
assert isinstance(aggregated["raw"], AIMessage)
assert aggregated["raw"].tool_calls
assert aggregated["parsed"] is None
@pytest.mark.scheduled
def test_prompt_cache_key_invoke() -> None:
"""Test that prompt_cache_key works with invoke calls."""
chat = ChatOpenAI(model="gpt-4o-mini", max_completion_tokens=20)
messages = [HumanMessage("Say hello")]
# Test that invoke works with prompt_cache_key parameter
response = chat.invoke(messages, prompt_cache_key="integration-test-v1")
assert isinstance(response, AIMessage)
assert isinstance(response.content, str)
assert len(response.content) > 0
# Test that subsequent call with same cache key also works
response2 = chat.invoke(messages, prompt_cache_key="integration-test-v1")
assert isinstance(response2, AIMessage)
assert isinstance(response2.content, str)
assert len(response2.content) > 0
@pytest.mark.scheduled
def test_prompt_cache_key_usage_methods_integration() -> None:
"""Integration test for prompt_cache_key usage methods."""
messages = [HumanMessage("Say hi")]
# Test keyword argument method
chat = ChatOpenAI(model="gpt-4o-mini", max_completion_tokens=10)
response = chat.invoke(messages, prompt_cache_key="integration-test-v1")
assert isinstance(response, AIMessage)
assert isinstance(response.content, str)
# Test model-level via model_kwargs
chat_model_level = ChatOpenAI(
model="gpt-4o-mini",
max_completion_tokens=10,
model_kwargs={"prompt_cache_key": "integration-model-level-v1"},
)
response_model_level = chat_model_level.invoke(messages)
assert isinstance(response_model_level, AIMessage)
assert isinstance(response_model_level.content, str)

View File

@@ -0,0 +1,84 @@
"""Unit tests for prompt_cache_key parameter."""
from langchain_core.messages import HumanMessage
from langchain_openai import ChatOpenAI
def test_prompt_cache_key_parameter_inclusion() -> None:
"""Test that prompt_cache_key parameter is properly included in request payload."""
chat = ChatOpenAI(model="gpt-4o-mini", max_completion_tokens=10)
messages = [HumanMessage("Hello")]
payload = chat._get_request_payload(messages, prompt_cache_key="test-cache-key")
assert "prompt_cache_key" in payload
assert payload["prompt_cache_key"] == "test-cache-key"
def test_prompt_cache_key_parameter_exclusion() -> None:
"""Test that prompt_cache_key parameter behavior matches OpenAI API."""
chat = ChatOpenAI(model="gpt-4o-mini", max_completion_tokens=10)
messages = [HumanMessage("Hello")]
# Test with explicit None (the OpenAI API accepts None since the parameter is Optional)
payload = chat._get_request_payload(messages, prompt_cache_key=None)
assert "prompt_cache_key" in payload
assert payload["prompt_cache_key"] is None
def test_prompt_cache_key_per_call() -> None:
"""Test that prompt_cache_key can be passed per-call with different values."""
chat = ChatOpenAI(model="gpt-4o-mini", max_completion_tokens=10)
messages = [HumanMessage("Hello")]
# Test different cache keys per call
payload1 = chat._get_request_payload(messages, prompt_cache_key="cache-v1")
payload2 = chat._get_request_payload(messages, prompt_cache_key="cache-v2")
assert payload1["prompt_cache_key"] == "cache-v1"
assert payload2["prompt_cache_key"] == "cache-v2"
# Test dynamic cache key assignment
cache_keys = ["customer-v1", "support-v1", "feedback-v1"]
for cache_key in cache_keys:
payload = chat._get_request_payload(messages, prompt_cache_key=cache_key)
assert "prompt_cache_key" in payload
assert payload["prompt_cache_key"] == cache_key
def test_prompt_cache_key_model_kwargs() -> None:
"""Test prompt_cache_key via model_kwargs and method precedence."""
messages = [HumanMessage("Hello world")]
# Test model-level via model_kwargs
chat = ChatOpenAI(
model="gpt-4o-mini",
max_completion_tokens=10,
model_kwargs={"prompt_cache_key": "model-level-cache"},
)
payload = chat._get_request_payload(messages)
assert "prompt_cache_key" in payload
assert payload["prompt_cache_key"] == "model-level-cache"
# Test that per-call cache key overrides model-level
payload_override = chat._get_request_payload(
messages, prompt_cache_key="per-call-cache"
)
assert payload_override["prompt_cache_key"] == "per-call-cache"
def test_prompt_cache_key_responses_api() -> None:
"""Test that prompt_cache_key works with Responses API."""
chat = ChatOpenAI(
model="gpt-4o-mini", use_responses_api=True, max_completion_tokens=10
)
messages = [HumanMessage("Hello")]
payload = chat._get_request_payload(
messages, prompt_cache_key="responses-api-cache-v1"
)
# prompt_cache_key should be present regardless of API type
assert "prompt_cache_key" in payload
assert payload["prompt_cache_key"] == "responses-api-cache-v1"

View File

@@ -995,7 +995,7 @@ wheels = [
[[package]]
name = "openai"
version = "1.98.0"
version = "1.99.2"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "anyio" },
@@ -1007,9 +1007,9 @@ dependencies = [
{ name = "tqdm" },
{ name = "typing-extensions" },
]
sdist = { url = "https://files.pythonhosted.org/packages/d8/9d/52eadb15c92802711d6b6cf00df3a6d0d18b588f4c5ba5ff210c6419fc03/openai-1.98.0.tar.gz", hash = "sha256:3ee0fcc50ae95267fd22bd1ad095ba5402098f3df2162592e68109999f685427", size = 496695, upload-time = "2025-07-30T12:48:03.701Z" }
sdist = { url = "https://files.pythonhosted.org/packages/de/2c/8cd1684364551237a5e6db24ce25c25ff54efcf1805b39110ec713dc2972/openai-1.99.2.tar.gz", hash = "sha256:118075b48109aa237636607b1346cf03b37cb9d74b0414cb11095850a0a22c96", size = 504752, upload-time = "2025-08-07T17:16:14.668Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/a8/fe/f64631075b3d63a613c0d8ab761d5941631a470f6fa87eaaee1aa2b4ec0c/openai-1.98.0-py3-none-any.whl", hash = "sha256:b99b794ef92196829120e2df37647722104772d2a74d08305df9ced5f26eae34", size = 767713, upload-time = "2025-07-30T12:48:01.264Z" },
{ url = "https://files.pythonhosted.org/packages/b7/06/f3338c1b8685dc1634fa5174dc5ba2d3eecc7887c9fc539bb5da6f75ebb3/openai-1.99.2-py3-none-any.whl", hash = "sha256:110d85b8ed400e1d7b02f8db8e245bd757bcde347cb6923155f42cd66a10aa0b", size = 785594, upload-time = "2025-08-07T17:16:13.083Z" },
]
[[package]]

View File

@@ -24,8 +24,14 @@ class RetrieversIntegrationTests(BaseStandardTests):
@property
@abstractmethod
def retriever_query_example(self) -> str:
"""Returns a str representing the "query" of an example retriever call."""
...
"""Returns a str representing the ``query`` of an example retriever call."""
@property
def num_results_arg_name(self) -> str:
"""Returns the name of the parameter for the number of results returned.
Usually something like ``k`` or ``top_k``."""
return "k"
@pytest.fixture
def retriever(self) -> BaseRetriever:
@@ -33,14 +39,34 @@ class RetrieversIntegrationTests(BaseStandardTests):
return self.retriever_constructor(**self.retriever_constructor_params)
def test_k_constructor_param(self) -> None:
"""Test that the retriever constructor accepts a k parameter, representing
"""Test the number of results constructor parameter.
Test that the retriever constructor accepts a parameter representing
the number of documents to return.
By default, the parameter tested is named ``k``, but it can be overridden by
setting the ``num_results_arg_name`` property.
.. note::
If the retriever doesn't support configuring the number of results returned
via the constructor, this test can be skipped using a pytest ``xfail`` on
the test class:
.. code-block:: python
@pytest.mark.xfail(
reason="This retriever doesn't support setting "
"the number of results via the constructor."
)
def test_k_constructor_param(self) -> None:
raise NotImplementedError
.. dropdown:: Troubleshooting
If this test fails, either the retriever constructor does not accept a k
parameter, or the retriever does not return the correct number of documents
(`k`) when it is set.
If this test fails, either the retriever constructor does not accept a
number-of-results parameter, or the retriever does not return the correct
number of documents (the number set via ``num_results_arg_name``) when it
is set.
For example, a retriever like
@@ -52,29 +78,51 @@ class RetrieversIntegrationTests(BaseStandardTests):
"""
params = {
k: v for k, v in self.retriever_constructor_params.items() if k != "k"
k: v
for k, v in self.retriever_constructor_params.items()
if k != self.num_results_arg_name
}
params_3 = {**params, "k": 3}
params_3 = {**params, self.num_results_arg_name: 3}
retriever_3 = self.retriever_constructor(**params_3)
result_3 = retriever_3.invoke(self.retriever_query_example)
assert len(result_3) == 3
assert all(isinstance(doc, Document) for doc in result_3)
params_1 = {**params, "k": 1}
params_1 = {**params, self.num_results_arg_name: 1}
retriever_1 = self.retriever_constructor(**params_1)
result_1 = retriever_1.invoke(self.retriever_query_example)
assert len(result_1) == 1
assert all(isinstance(doc, Document) for doc in result_1)
def test_invoke_with_k_kwarg(self, retriever: BaseRetriever) -> None:
"""Test that the invoke method accepts a k parameter, representing the number of
documents to return.
"""Test the number of results parameter in ``invoke()``.
Test that the invoke method accepts a parameter representing
the number of documents to return.
By default, the parameter is named ``k``, but it can be overridden by
setting the ``num_results_arg_name`` property.
.. note::
If the retriever doesn't support configuring the number of results returned
via the invoke method, this test can be skipped using a pytest ``xfail`` on
the test class:
.. code-block:: python
@pytest.mark.xfail(
reason="This retriever doesn't support setting "
"the number of results in the invoke method."
)
def test_invoke_with_k_kwarg(self) -> None:
raise NotImplementedError
.. dropdown:: Troubleshooting
If this test fails, the retriever's invoke method does not accept a k
parameter, or the retriever does not return the correct number of documents
(`k`) when it is set.
If this test fails, either the retriever's invoke method does not accept a
number-of-results parameter, or the retriever does not return the correct
number of documents (the number set via ``num_results_arg_name``) when it
is set.
For example, a retriever like
@@ -85,11 +133,15 @@ class RetrieversIntegrationTests(BaseStandardTests):
should return 3 documents when invoked with a query.
"""
result_1 = retriever.invoke(self.retriever_query_example, k=1)
result_1 = retriever.invoke(
self.retriever_query_example, None, **{self.num_results_arg_name: 1}
)
assert len(result_1) == 1
assert all(isinstance(doc, Document) for doc in result_1)
result_3 = retriever.invoke(self.retriever_query_example, k=3)
result_3 = retriever.invoke(
self.retriever_query_example, None, **{self.num_results_arg_name: 3}
)
assert len(result_3) == 3
assert all(isinstance(doc, Document) for doc in result_3)
@@ -100,8 +152,8 @@ class RetrieversIntegrationTests(BaseStandardTests):
.. dropdown:: Troubleshooting
If this test fails, the retriever's invoke method does not return a list of
`langchain_core.document.Document` objects. Please confirm that your
`_get_relevant_documents` method returns a list of `Document` objects.
``langchain_core.document.Document`` objects. Please confirm that your
``_get_relevant_documents`` method returns a list of ``Document`` objects.
"""
result = retriever.invoke(self.retriever_query_example)
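As a usage sketch of the new ``num_results_arg_name`` hook: a test class for a retriever whose result-count parameter is called ``top_k`` rather than ``k``. The retriever class is hypothetical and the ``langchain_tests`` import path is an assumption, not part of this diff.

```python
from langchain_core.retrievers import BaseRetriever
from langchain_tests.integration_tests import RetrieversIntegrationTests

from my_package.retrievers import MyRetriever  # hypothetical retriever


class TestMyRetriever(RetrieversIntegrationTests):
    @property
    def retriever_constructor(self) -> type[BaseRetriever]:
        return MyRetriever

    @property
    def retriever_constructor_params(self) -> dict:
        return {"top_k": 3}

    @property
    def retriever_query_example(self) -> str:
        return "example query"

    @property
    def num_results_arg_name(self) -> str:
        # This retriever exposes "top_k" instead of the default "k".
        return "top_k"
```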