mirror of https://github.com/hwchase17/langchain.git, synced 2025-08-12 06:13:36 +00:00
Merge branch 'master' into wip-v0.4

commit cbf4c0e565
.github/workflows/_release.yml (vendored, 3 changed lines)
@@ -388,11 +388,12 @@ jobs:
       - name: Test against ${{ matrix.partner }}
         if: startsWith(inputs.working-directory, 'libs/core')
         run: |
-          # Identify latest tag
+          # Identify latest tag, excluding pre-releases
           LATEST_PACKAGE_TAG="$(
             git ls-remote --tags origin "langchain-${{ matrix.partner }}*" \
               | awk '{print $2}' \
               | sed 's|refs/tags/||' \
+              | grep -Ev '==[^=]*(\.?dev[0-9]*|\.?rc[0-9]*)$' \
               | sort -Vr \
               | head -n 1
           )"
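For reference, the added `grep -Ev` filter can be checked in isolation; below is a minimal Python sketch of the same exclusion rule (the tag names are illustrative, not taken from the repository):

```python
import re

# Same pattern as the grep above: drop tags whose version part ends in a
# dev or rc pre-release segment.
PRE_RELEASE = re.compile(r"==[^=]*(\.?dev[0-9]*|\.?rc[0-9]*)$")

tags = [
    "langchain-core==0.3.72",       # stable: kept
    "langchain-core==0.3.73rc1",    # release candidate: dropped
    "langchain-core==0.3.73.dev0",  # dev build: dropped
]
print([t for t in tags if not PRE_RELEASE.search(t)])
# ['langchain-core==0.3.72']
```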
File diff suppressed because it is too large.
@@ -29,8 +29,8 @@
     " Please refer to the instructions in:\n",
     " [www.jaguardb.com](http://www.jaguardb.com)\n",
     " For quick setup in docker environment:\n",
-    " docker pull jaguardb/jaguardb_with_http\n",
-    " docker run -d -p 8888:8888 -p 8080:8080 --name jaguardb_with_http jaguardb/jaguardb_with_http\n",
+    " docker pull jaguardb/jaguardb\n",
+    " docker run -d -p 8888:8888 -p 8080:8080 --name jaguardb jaguardb/jaguardb\n",
     "\n",
     "2. You must install the http client package for JaguarDB:\n",
     " ```\n",
@@ -666,6 +666,16 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
                 converted_generations.append(chat_gen)
             else:
                 # Already a ChatGeneration or other expected type
+                if hasattr(gen, "message") and isinstance(gen.message, AIMessage):
+                    # We zero out cost on cache hits
+                    gen.message = gen.message.model_copy(
+                        update={
+                            "usage_metadata": {
+                                **(gen.message.usage_metadata or {}),
+                                "total_cost": 0,
+                            }
+                        }
+                    )
                 converted_generations.append(gen)
         return converted_generations
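The `model_copy(update=...)` call merges the zeroed cost into a copy of the cached message instead of mutating it in place, which matters because cached entries may be shared. A minimal sketch of the same pattern (values illustrative; note that `total_cost` is not a declared key of the standard `UsageMetadata` TypedDict, which is why the test below needs a `type: ignore`):

```python
from langchain_core.messages import AIMessage

msg = AIMessage(
    content="hi",
    usage_metadata={"input_tokens": 5, "output_tokens": 10, "total_tokens": 15},
)
# Merge a zeroed cost into a copy; the original message is untouched.
updated = msg.model_copy(
    update={"usage_metadata": {**(msg.usage_metadata or {}), "total_cost": 0}}
)
assert updated.usage_metadata["total_cost"] == 0  # type: ignore[typeddict-item]
```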
@@ -458,3 +458,23 @@ def test_cleanup_serialized() -> None:
         "name": "CustomChat",
         "type": "constructor",
     }
+
+
+def test_token_costs_are_zeroed_out() -> None:
+    # We zero out token costs for cache hits
+    local_cache = InMemoryCache()
+    messages = [
+        AIMessage(
+            content="Hello, how are you?",
+            usage_metadata={"input_tokens": 5, "output_tokens": 10, "total_tokens": 15},
+        ),
+    ]
+    model = GenericFakeChatModel(messages=iter(messages), cache=local_cache)
+    first_response = model.invoke("Hello")
+    assert isinstance(first_response, AIMessage)
+    assert first_response.usage_metadata
+
+    second_response = model.invoke("Hello")
+    assert isinstance(second_response, AIMessage)
+    assert second_response.usage_metadata
+    assert second_response.usage_metadata["total_cost"] == 0  # type: ignore[typeddict-item]
libs/langchain_v1/langchain/_internal/__init__.py (new file, 0 lines)

libs/langchain_v1/langchain/_internal/_documents.py (new file, 39 lines)
@@ -0,0 +1,39 @@
"""Internal document utilities."""

from __future__ import annotations

from typing import TYPE_CHECKING

if TYPE_CHECKING:
    from langchain_core.documents import Document


def format_document_xml(doc: Document) -> str:
    """Format a document as an XML-like structure for LLM consumption.

    Args:
        doc: Document to format.

    Returns:
        Document wrapped in XML tags:

        <document>
            <id>...</id>
            <content>...</content>
            <metadata>...</metadata>
        </document>

    Note:
        Does not generate valid XML or escape special characters.
        Intended for semi-structured LLM input only.
    """
    id_str = f"<id>{doc.id}</id>" if doc.id is not None else "<id></id>"
    metadata_str = ""
    if doc.metadata:
        metadata_items = [f"{k}: {v!s}" for k, v in doc.metadata.items()]
        metadata_str = f"<metadata>{', '.join(metadata_items)}</metadata>"
    return (
        f"<document>{id_str}"
        f"<content>{doc.page_content}</content>"
        f"{metadata_str}"
        f"</document>"
    )
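A quick usage sketch of `format_document_xml` (the document values are illustrative):

```python
from langchain_core.documents import Document

doc = Document(
    id="doc-1",
    page_content="LangGraph ships a Send API.",
    metadata={"source": "notes.md"},
)
print(format_document_xml(doc))
# Prints a single line:
# <document><id>doc-1</id><content>LangGraph ships a Send API.</content><metadata>source: notes.md</metadata></document>
```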
libs/langchain_v1/langchain/_internal/_lazy_import.py (new file, 36 lines)
@@ -0,0 +1,36 @@
"""Lazy import utilities."""

from importlib import import_module
from typing import Union


def import_attr(
    attr_name: str,
    module_name: Union[str, None],
    package: Union[str, None],
) -> object:
    """Import an attribute from a module located in a package.

    This utility function is used in custom __getattr__ methods within
    __init__.py files to dynamically import attributes.

    Args:
        attr_name: The name of the attribute to import.
        module_name: The name of the module to import from. If None, the
            attribute is imported from the package itself.
        package: The name of the package where the module is located.
    """
    if module_name == "__module__" or module_name is None:
        try:
            result = import_module(f".{attr_name}", package=package)
        except ModuleNotFoundError:
            msg = f"module '{package!r}' has no attribute {attr_name!r}"
            raise AttributeError(msg) from None
    else:
        try:
            module = import_module(f".{module_name}", package=package)
        except ModuleNotFoundError as err:
            msg = f"module '{package!r}.{module_name!r}' not found ({err})"
            raise ImportError(msg) from None
        result = getattr(module, attr_name)
    return result
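A sketch of how `import_attr` typically backs a lazy `__getattr__`; the package layout and names here are hypothetical:

```python
# hypothetical mypkg/__init__.py
from langchain._internal._lazy_import import import_attr

_DYNAMIC_IMPORTS = {"BigHelper": "big_module"}  # attribute -> submodule


def __getattr__(name: str) -> object:
    if name in _DYNAMIC_IMPORTS:
        # mypkg.big_module is imported only on first attribute access.
        return import_attr(name, _DYNAMIC_IMPORTS[name], __name__)
    msg = f"module {__name__!r} has no attribute {name!r}"
    raise AttributeError(msg)
```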
libs/langchain_v1/langchain/_internal/_prompts.py (new file, 166 lines)
@@ -0,0 +1,166 @@
"""Internal prompt resolution utilities.

This module provides utilities for resolving different types of prompt
specifications into standardized message formats for language models. It
supports both synchronous and asynchronous prompt resolution with automatic
detection of callable types.

The module is designed to handle common prompt patterns across LangChain
components, particularly for summarization chains and other document
processing workflows.
"""

from __future__ import annotations

import inspect
from typing import TYPE_CHECKING, Callable, Union

if TYPE_CHECKING:
    from collections.abc import Awaitable

    from langchain_core.messages import MessageLikeRepresentation
    from langgraph.runtime import Runtime

    from langchain._internal._typing import ContextT, StateT


def resolve_prompt(
    prompt: Union[
        str,
        None,
        Callable[[StateT, Runtime[ContextT]], list[MessageLikeRepresentation]],
    ],
    state: StateT,
    runtime: Runtime[ContextT],
    default_user_content: str,
    default_system_content: str,
) -> list[MessageLikeRepresentation]:
    """Resolve a prompt specification into a list of messages.

    Handles prompt resolution across different strategies. Supports callable
    functions, string system messages, and None for default behavior.

    Args:
        prompt: The prompt specification to resolve. Can be:
            - Callable: Function taking (state, runtime) returning a message list.
            - str: A system message string.
            - None: Use the provided default system message.
        state: Current state, passed to callable prompts.
        runtime: LangGraph runtime instance, passed to callable prompts.
        default_user_content: User content to include (e.g., document text).
        default_system_content: Default system message when prompt is None.

    Returns:
        List of message dictionaries for language models, typically containing
        a system message and a user message with content.

    Raises:
        TypeError: If the prompt type is not str, None, or callable.

    Example:
        ```python
        def custom_prompt(state, runtime):
            return [{"role": "system", "content": "Custom"}]

        messages = resolve_prompt(custom_prompt, state, runtime, "content", "default")
        messages = resolve_prompt("Custom system", state, runtime, "content", "default")
        messages = resolve_prompt(None, state, runtime, "content", "Default")
        ```

    Note:
        Callable prompts have full control over the message structure, and the
        default content parameters are ignored. String/None prompts create the
        standard system + user structure.
    """
    if callable(prompt):
        return prompt(state, runtime)
    if isinstance(prompt, str):
        system_msg = prompt
    elif prompt is None:
        system_msg = default_system_content
    else:
        msg = f"Invalid prompt type: {type(prompt)}. Expected str, None, or callable."
        raise TypeError(msg)

    return [
        {"role": "system", "content": system_msg},
        {"role": "user", "content": default_user_content},
    ]


async def aresolve_prompt(
    prompt: Union[
        str,
        None,
        Callable[[StateT, Runtime[ContextT]], list[MessageLikeRepresentation]],
        Callable[
            [StateT, Runtime[ContextT]], Awaitable[list[MessageLikeRepresentation]]
        ],
    ],
    state: StateT,
    runtime: Runtime[ContextT],
    default_user_content: str,
    default_system_content: str,
) -> list[MessageLikeRepresentation]:
    """Async version of resolve_prompt supporting both sync and async callables.

    Handles prompt resolution across different strategies. Supports sync/async
    callable functions, string system messages, and None for default behavior.

    Args:
        prompt: The prompt specification to resolve. Can be:
            - Callable (sync): Function taking (state, runtime) returning a
              message list.
            - Callable (async): Async function taking (state, runtime) returning
              an awaitable message list.
            - str: A system message string.
            - None: Use the provided default system message.
        state: Current state, passed to callable prompts.
        runtime: LangGraph runtime instance, passed to callable prompts.
        default_user_content: User content to include (e.g., document text).
        default_system_content: Default system message when prompt is None.

    Returns:
        List of message dictionaries for language models, typically containing
        a system message and a user message with content.

    Raises:
        TypeError: If the prompt type is not str, None, or callable.

    Example:
        ```python
        async def async_prompt(state, runtime):
            return [{"role": "system", "content": "Async"}]

        def sync_prompt(state, runtime):
            return [{"role": "system", "content": "Sync"}]

        messages = await aresolve_prompt(
            async_prompt, state, runtime, "content", "default"
        )
        messages = await aresolve_prompt(
            sync_prompt, state, runtime, "content", "default"
        )
        messages = await aresolve_prompt("Custom", state, runtime, "content", "default")
        ```

    Note:
        Callable prompts have full control over the message structure, and the
        default content parameters are ignored. Async callables are detected
        and awaited automatically.
    """
    if callable(prompt):
        result = prompt(state, runtime)
        # Check if the result is awaitable (async function)
        if inspect.isawaitable(result):
            return await result
        return result
    if isinstance(prompt, str):
        system_msg = prompt
    elif prompt is None:
        system_msg = default_system_content
    else:
        msg = f"Invalid prompt type: {type(prompt)}. Expected str, None, or callable."
        raise TypeError(msg)

    return [
        {"role": "system", "content": system_msg},
        {"role": "user", "content": default_user_content},
    ]
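The `inspect.isawaitable` check is what lets `aresolve_prompt` accept both callable kinds through one entry point. A standalone sketch of the same detection pattern:

```python
import asyncio
import inspect


async def call_maybe_async(fn, *args):
    # Mirrors aresolve_prompt: call first, then await only if needed.
    result = fn(*args)
    if inspect.isawaitable(result):
        return await result
    return result


def sync_fn(x):
    return x + 1


async def async_fn(x):
    return x + 1


print(asyncio.run(call_maybe_async(sync_fn, 1)))   # 2
print(asyncio.run(call_maybe_async(async_fn, 1)))  # 2
```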
libs/langchain_v1/langchain/_internal/_typing.py (new file, 65 lines)
@@ -0,0 +1,65 @@
"""Private typing utilities for langchain."""

from __future__ import annotations

from typing import TYPE_CHECKING, Any, ClassVar, Protocol, TypeVar, Union

from langgraph.graph._node import StateNode
from pydantic import BaseModel
from typing_extensions import TypeAlias

if TYPE_CHECKING:
    from dataclasses import Field


class TypedDictLikeV1(Protocol):
    """Protocol to represent types that behave like TypedDicts.

    Version 1: using `ClassVar` for keys.
    """

    __required_keys__: ClassVar[frozenset[str]]
    __optional_keys__: ClassVar[frozenset[str]]


class TypedDictLikeV2(Protocol):
    """Protocol to represent types that behave like TypedDicts.

    Version 2: not using `ClassVar` for keys.
    """

    __required_keys__: frozenset[str]
    __optional_keys__: frozenset[str]


class DataclassLike(Protocol):
    """Protocol to represent types that behave like dataclasses.

    Inspired by the private _DataclassT from dataclasses that uses a similar
    protocol as a bound.
    """

    __dataclass_fields__: ClassVar[dict[str, Field[Any]]]


StateLike: TypeAlias = Union[TypedDictLikeV1, TypedDictLikeV2, DataclassLike, BaseModel]
"""Type alias for state-like types.

It can either be a `TypedDict`, `dataclass`, or Pydantic `BaseModel`.
Note: we cannot use either `TypedDict` or `dataclass` directly due to
limitations in type checking.
"""

StateT = TypeVar("StateT", bound=StateLike)
"""Type variable used to represent the state in a graph."""

ContextT = TypeVar("ContextT", bound=Union[StateLike, None])
"""Type variable for context types."""


__all__ = [
    "ContextT",
    "StateLike",
    "StateNode",
    "StateT",
]
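To illustrate why the union needs structural protocols, a sketch of the three state styles that all satisfy `StateLike` (class names illustrative): a `TypedDict` matches via `__required_keys__`/`__optional_keys__`, a dataclass via `__dataclass_fields__`, and a pydantic model by being named in the union directly.

```python
from dataclasses import dataclass

from pydantic import BaseModel
from typing_extensions import TypedDict


class DictState(TypedDict):
    documents: list[str]


@dataclass
class DataclassState:
    documents: list[str]


class ModelState(BaseModel):
    documents: list[str]

# Any of the three can be bound to StateT (bound=StateLike) in a graph.
```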
libs/langchain_v1/langchain/_internal/_utils.py (new file, 7 lines)
@@ -0,0 +1,7 @@
# Re-exporting internal utilities from LangGraph for internal use in LangChain.
# A different wrapper needs to be created for this purpose in LangChain.
from langgraph._internal._runnable import RunnableCallable

__all__ = [
    "RunnableCallable",
]
libs/langchain_v1/langchain/chains/__init__.py (new file, 9 lines)
@@ -0,0 +1,9 @@
from langchain.chains.documents import (
    create_map_reduce_chain,
    create_stuff_documents_chain,
)

__all__ = [
    "create_map_reduce_chain",
    "create_stuff_documents_chain",
]
libs/langchain_v1/langchain/chains/documents/__init__.py (new file, 17 lines)
@@ -0,0 +1,17 @@
"""Document extraction chains.

This module provides different strategies for extracting information from
collections of documents using LangGraph and modern language models.

Available Strategies:
- Stuff: Processes all documents together in a single context window
- Map-Reduce: Processes documents in parallel (map), then combines results (reduce)
"""

from langchain.chains.documents.map_reduce import create_map_reduce_chain
from langchain.chains.documents.stuff import create_stuff_documents_chain

__all__ = [
    "create_map_reduce_chain",
    "create_stuff_documents_chain",
]
libs/langchain_v1/langchain/chains/documents/map_reduce.py (new file, 586 lines)
@@ -0,0 +1,586 @@
"""Map-Reduce Extraction Implementation using the LangGraph Send API."""

from __future__ import annotations

import operator
from typing import (
    TYPE_CHECKING,
    Annotated,
    Any,
    Callable,
    Generic,
    Literal,
    Optional,
    Union,
    cast,
)

from langgraph.graph import END, START, StateGraph
from langgraph.types import Send
from typing_extensions import NotRequired, TypedDict

from langchain._internal._documents import format_document_xml
from langchain._internal._prompts import aresolve_prompt, resolve_prompt
from langchain._internal._typing import ContextT, StateNode
from langchain._internal._utils import RunnableCallable
from langchain.chat_models import init_chat_model

if TYPE_CHECKING:
    from langchain_core.documents import Document
    from langchain_core.language_models.chat_models import BaseChatModel

    # PyCharm is unable to identify that AIMessage is used in the cast below
    from langchain_core.messages import (
        AIMessage,
        MessageLikeRepresentation,
    )
    from langchain_core.runnables import RunnableConfig
    from langgraph.runtime import Runtime
    from pydantic import BaseModel


class ExtractionResult(TypedDict):
    """Result from processing a document or group of documents."""

    indexes: list[int]
    """Document indexes that contributed to this result."""
    result: Any
    """Extracted result from the document(s)."""


class MapReduceState(TypedDict):
    """State for the map-reduce extraction chain.

    This state tracks the map-reduce process where documents are processed
    in parallel during the map phase, then combined in the reduce phase.
    """

    documents: list[Document]
    """List of documents to process."""
    map_results: Annotated[list[ExtractionResult], operator.add]
    """Individual results from the map phase."""
    result: NotRequired[Any]
    """Final combined result from the reduce phase, if applicable."""


# The payload for the map phase is a list of documents and their indexes.
# The current implementation only supports a single document per map operation,
# but the structure allows for future expansion to process a group of documents.
# A user would provide an input split function that returns groups of documents
# to process together, if desired.
class MapState(TypedDict):
    """State for individual map operations."""

    documents: list[Document]
    """List of documents to process in the map phase."""
    indexes: list[int]
    """List of indexes of the documents in the original list."""


class InputSchema(TypedDict):
    """Input schema for the map-reduce extraction chain.

    Defines the expected input format when invoking the extraction chain.
    """

    documents: list[Document]
    """List of documents to process."""


class OutputSchema(TypedDict):
    """Output schema for the map-reduce extraction chain.

    Defines the format of the final result returned by the chain.
    """

    map_results: list[ExtractionResult]
    """List of individual extraction results from the map phase."""

    result: Any
    """Final combined result from all documents."""


class MapReduceNodeUpdate(TypedDict):
    """Update returned by map-reduce nodes."""

    map_results: NotRequired[list[ExtractionResult]]
    """Updated results after the map phase."""
    result: NotRequired[Any]
    """Final result after the reduce phase."""


class _MapReduceExtractor(Generic[ContextT]):
    """Map-reduce extraction implementation using the LangGraph Send API.

    This implementation uses a language model to process documents through up
    to two phases:

    1. **Map Phase**: Each document is processed independently by the LLM using
       the configured map_prompt to generate individual extraction results.
    2. **Reduce Phase (Optional)**: Individual results can optionally be
       combined using either:

       - The default LLM-based reducer with the configured reduce_prompt
       - A custom reducer function (which can be non-LLM based)
       - Skipped entirely by setting reduce=None

    The map phase processes documents in parallel for efficiency, making this
    approach well-suited for large document collections. The reduce phase is
    flexible and can be customized or omitted based on your specific
    requirements.
    """

    def __init__(
        self,
        model: Union[BaseChatModel, str],
        *,
        map_prompt: Union[
            str,
            None,
            Callable[
                [MapState, Runtime[ContextT]],
                list[MessageLikeRepresentation],
            ],
        ] = None,
        reduce_prompt: Union[
            str,
            None,
            Callable[
                [MapReduceState, Runtime[ContextT]],
                list[MessageLikeRepresentation],
            ],
        ] = None,
        reduce: Union[
            Literal["default_reducer"],
            None,
            StateNode,
        ] = "default_reducer",
        context_schema: type[ContextT] | None = None,
        response_format: Optional[type[BaseModel]] = None,
    ) -> None:
        """Initialize the MapReduceExtractor.

        Args:
            model: The language model, either a chat model instance
                (e.g., `ChatAnthropic()`) or a string identifier
                (e.g., `"anthropic:claude-sonnet-4-20250514"`).
            map_prompt: Prompt for individual document processing. Can be:
                - str: A system message string
                - None: Use the default system message
                - Callable: A function that takes (state, runtime) and returns messages
            reduce_prompt: Prompt for combining results. Can be:
                - str: A system message string
                - None: Use the default system message
                - Callable: A function that takes (state, runtime) and returns messages
            reduce: Controls the reduce behavior. Can be:
                - "default_reducer": Use the default LLM-based reduce step
                - None: Skip the reduce step entirely
                - Callable: Custom reduce function (sync or async)
            context_schema: Optional context schema for the LangGraph runtime.
            response_format: Optional pydantic BaseModel for structured output.
        """
        if (reduce is None or callable(reduce)) and reduce_prompt is not None:
            msg = (
                "reduce_prompt must be None when reduce is None or a custom "
                "callable. Custom reduce functions handle their own logic and "
                "should not use reduce_prompt."
            )
            raise ValueError(msg)

        self.response_format = response_format

        if isinstance(model, str):
            model = init_chat_model(model)

        self.model = (
            model.with_structured_output(response_format) if response_format else model
        )
        self.map_prompt = map_prompt
        self.reduce_prompt = reduce_prompt
        self.reduce = reduce
        self.context_schema = context_schema

    def _get_map_prompt(
        self, state: MapState, runtime: Runtime[ContextT]
    ) -> list[MessageLikeRepresentation]:
        """Generate the LLM prompt for processing documents."""
        documents = state["documents"]
        user_content = "\n\n".join(format_document_xml(doc) for doc in documents)
        default_system = (
            "You are a helpful assistant that processes documents. "
            "Please process the following documents and provide a result."
        )

        return resolve_prompt(
            self.map_prompt,
            state,
            runtime,
            user_content,
            default_system,
        )

    async def _aget_map_prompt(
        self, state: MapState, runtime: Runtime[ContextT]
    ) -> list[MessageLikeRepresentation]:
        """Generate the LLM prompt for processing documents in the map phase.

        Async version.
        """
        documents = state["documents"]
        user_content = "\n\n".join(format_document_xml(doc) for doc in documents)
        default_system = (
            "You are a helpful assistant that processes documents. "
            "Please process the following documents and provide a result."
        )

        return await aresolve_prompt(
            self.map_prompt,
            state,
            runtime,
            user_content,
            default_system,
        )

    def _get_reduce_prompt(
        self, state: MapReduceState, runtime: Runtime[ContextT]
    ) -> list[MessageLikeRepresentation]:
        """Generate the LLM prompt for combining individual results.

        Combines map results in the reduce phase.
        """
        map_results = state.get("map_results", [])
        if not map_results:
            msg = (
                "Internal programming error: Results must exist when reducing. "
                "This indicates that the reduce node was reached without "
                "first processing the map nodes, which violates "
                "the expected graph execution order."
            )
            raise AssertionError(msg)

        results_text = "\n\n".join(
            f"Result {i + 1} (from documents "
            f"{', '.join(map(str, result['indexes']))}):\n{result['result']}"
            for i, result in enumerate(map_results)
        )
        user_content = (
            f"Please combine the following results into a single, "
            f"comprehensive result:\n\n{results_text}"
        )
        default_system = (
            "You are a helpful assistant that combines multiple results. "
            "Given several individual results, create a single comprehensive "
            "result that captures the key information from all inputs while "
            "maintaining conciseness and coherence."
        )

        return resolve_prompt(
            self.reduce_prompt,
            state,
            runtime,
            user_content,
            default_system,
        )

    async def _aget_reduce_prompt(
        self, state: MapReduceState, runtime: Runtime[ContextT]
    ) -> list[MessageLikeRepresentation]:
        """Generate the LLM prompt for combining individual results.

        Async version of the reduce phase.
        """
        map_results = state.get("map_results", [])
        if not map_results:
            msg = (
                "Internal programming error: Results must exist when reducing. "
                "This indicates that the reduce node was reached without "
                "first processing the map nodes, which violates "
                "the expected graph execution order."
            )
            raise AssertionError(msg)

        results_text = "\n\n".join(
            f"Result {i + 1} (from documents "
            f"{', '.join(map(str, result['indexes']))}):\n{result['result']}"
            for i, result in enumerate(map_results)
        )
        user_content = (
            f"Please combine the following results into a single, "
            f"comprehensive result:\n\n{results_text}"
        )
        default_system = (
            "You are a helpful assistant that combines multiple results. "
            "Given several individual results, create a single comprehensive "
            "result that captures the key information from all inputs while "
            "maintaining conciseness and coherence."
        )

        return await aresolve_prompt(
            self.reduce_prompt,
            state,
            runtime,
            user_content,
            default_system,
        )

    def create_map_node(self) -> RunnableCallable:
        """Create a LangGraph node that processes individual documents using the LLM."""

        def _map_node(
            state: MapState, runtime: Runtime[ContextT], config: RunnableConfig
        ) -> dict[str, list[ExtractionResult]]:
            prompt = self._get_map_prompt(state, runtime)
            response = cast("AIMessage", self.model.invoke(prompt, config=config))
            result = response if self.response_format else response.text()
            extraction_result: ExtractionResult = {
                "indexes": state["indexes"],
                "result": result,
            }
            return {"map_results": [extraction_result]}

        async def _amap_node(
            state: MapState,
            runtime: Runtime[ContextT],
            config: RunnableConfig,
        ) -> dict[str, list[ExtractionResult]]:
            prompt = await self._aget_map_prompt(state, runtime)
            response = cast(
                "AIMessage", await self.model.ainvoke(prompt, config=config)
            )
            result = response if self.response_format else response.text()
            extraction_result: ExtractionResult = {
                "indexes": state["indexes"],
                "result": result,
            }
            return {"map_results": [extraction_result]}

        return RunnableCallable(
            _map_node,
            _amap_node,
            trace=False,
        )

    def create_reduce_node(self) -> RunnableCallable:
        """Create a LangGraph node that combines individual results using the LLM."""

        def _reduce_node(
            state: MapReduceState, runtime: Runtime[ContextT], config: RunnableConfig
        ) -> MapReduceNodeUpdate:
            prompt = self._get_reduce_prompt(state, runtime)
            response = cast("AIMessage", self.model.invoke(prompt, config=config))
            result = response if self.response_format else response.text()
            return {"result": result}

        async def _areduce_node(
            state: MapReduceState,
            runtime: Runtime[ContextT],
            config: RunnableConfig,
        ) -> MapReduceNodeUpdate:
            prompt = await self._aget_reduce_prompt(state, runtime)
            response = cast(
                "AIMessage", await self.model.ainvoke(prompt, config=config)
            )
            result = response if self.response_format else response.text()
            return {"result": result}

        return RunnableCallable(
            _reduce_node,
            _areduce_node,
            trace=False,
        )

    def continue_to_map(self, state: MapReduceState) -> list[Send]:
        """Generate Send objects for parallel map operations."""
        return [
            Send("map_process", {"documents": [doc], "indexes": [i]})
            for i, doc in enumerate(state["documents"])
        ]

    def build(
        self,
    ) -> StateGraph[MapReduceState, ContextT, InputSchema, OutputSchema]:
        """Build and compile the LangGraph for map-reduce summarization."""
        builder = StateGraph(
            MapReduceState,
            context_schema=self.context_schema,
            input_schema=InputSchema,
            output_schema=OutputSchema,
        )

        builder.add_node("map_process", self.create_map_node())

        builder.add_edge(START, "continue_to_map")
        # add_conditional_edges doesn't explicitly type Send
        builder.add_conditional_edges(
            "continue_to_map",
            self.continue_to_map,  # type: ignore[arg-type]
            ["map_process"],
        )

        if self.reduce is None:
            builder.add_edge("map_process", END)
        elif self.reduce == "default_reducer":
            builder.add_node("reduce_process", self.create_reduce_node())
            builder.add_edge("map_process", "reduce_process")
            builder.add_edge("reduce_process", END)
        else:
            reduce_node = cast("StateNode", self.reduce)
            # The type is ignored here. Requires parameterizing with generics.
            builder.add_node("reduce_process", reduce_node)  # type: ignore[arg-type]
            builder.add_edge("map_process", "reduce_process")
            builder.add_edge("reduce_process", END)

        return builder


def create_map_reduce_chain(
    model: Union[BaseChatModel, str],
    *,
    map_prompt: Union[
        str,
        None,
        Callable[[MapState, Runtime[ContextT]], list[MessageLikeRepresentation]],
    ] = None,
    reduce_prompt: Union[
        str,
        None,
        Callable[[MapReduceState, Runtime[ContextT]], list[MessageLikeRepresentation]],
    ] = None,
    reduce: Union[
        Literal["default_reducer"],
        None,
        StateNode,
    ] = "default_reducer",
    context_schema: type[ContextT] | None = None,
    response_format: Optional[type[BaseModel]] = None,
) -> StateGraph[MapReduceState, ContextT, InputSchema, OutputSchema]:
    """Create a map-reduce document extraction chain.

    This implementation uses a language model to extract information from
    documents through a flexible approach that efficiently handles large
    document collections by processing documents in parallel.

    **Processing Flow:**

    1. **Map Phase**: Each document is independently processed by the LLM
       using the map_prompt to extract relevant information and generate
       individual results.
    2. **Reduce Phase (Optional)**: Individual extraction results can
       optionally be combined using:

       - The default LLM-based reducer with reduce_prompt (default behavior)
       - A custom reducer function (can be non-LLM based)
       - Skipped entirely by setting reduce=None

    3. **Output**: Returns the individual map results and optionally the final
       combined result.

    Example:
        >>> from langchain_anthropic import ChatAnthropic
        >>> from langchain_core.documents import Document
        >>>
        >>> model = ChatAnthropic(
        ...     model="claude-sonnet-4-20250514",
        ...     temperature=0,
        ...     max_tokens=62_000,
        ...     timeout=None,
        ...     max_retries=2,
        ... )
        >>> builder = create_map_reduce_chain(model)
        >>> chain = builder.compile()
        >>> docs = [
        ...     Document(page_content="First document content..."),
        ...     Document(page_content="Second document content..."),
        ...     Document(page_content="Third document content..."),
        ... ]
        >>> result = chain.invoke({"documents": docs})
        >>> print(result["result"])

    Example with string model:
        >>> builder = create_map_reduce_chain("anthropic:claude-sonnet-4-20250514")
        >>> chain = builder.compile()
        >>> result = chain.invoke({"documents": docs})
        >>> print(result["result"])

    Example with structured output:
        ```python
        from pydantic import BaseModel

        class ExtractionModel(BaseModel):
            title: str
            key_points: list[str]
            conclusion: str

        builder = create_map_reduce_chain(
            model,
            response_format=ExtractionModel,
        )
        chain = builder.compile()
        result = chain.invoke({"documents": docs})
        print(result["result"].title)  # Access structured fields
        ```

    Example skipping the reduce phase:
        ```python
        # Only perform the map phase, skip combining results
        builder = create_map_reduce_chain(model, reduce=None)
        chain = builder.compile()
        result = chain.invoke({"documents": docs})
        # result["result"] will be None; only map_results are available
        for map_result in result["map_results"]:
            print(f"Document {map_result['indexes'][0]}: {map_result['result']}")
        ```

    Example with custom reducer:
        ```python
        def custom_aggregator(state, runtime):
            # Custom non-LLM based reduction logic
            map_results = state["map_results"]
            combined_text = " | ".join(r["result"] for r in map_results)
            word_count = len(combined_text.split())
            return {
                "result": f"Combined {len(map_results)} results with "
                f"{word_count} total words"
            }

        builder = create_map_reduce_chain(model, reduce=custom_aggregator)
        chain = builder.compile()
        result = chain.invoke({"documents": docs})
        print(result["result"])  # Custom aggregated result
        ```

    Args:
        model: The language model, either a chat model instance
            (e.g., `ChatAnthropic()`) or a string identifier
            (e.g., `"anthropic:claude-sonnet-4-20250514"`).
        map_prompt: Prompt for individual document processing. Can be:
            - str: A system message string
            - None: Use the default system message
            - Callable: A function that takes (state, runtime) and returns messages
        reduce_prompt: Prompt for combining results. Can be:
            - str: A system message string
            - None: Use the default system message
            - Callable: A function that takes (state, runtime) and returns messages
        reduce: Controls the reduce behavior. Can be:
            - "default_reducer": Use the default LLM-based reduce step
            - None: Skip the reduce step entirely
            - Callable: Custom reduce function (sync or async)
        context_schema: Optional context schema for the LangGraph runtime.
        response_format: Optional pydantic BaseModel for structured output.

    Returns:
        A LangGraph that can be invoked with documents to get map-reduce
        extraction results.

    Note:
        This implementation is well-suited for large document collections as it
        processes documents in parallel during the map phase. The Send API
        enables efficient parallelization while maintaining clean state
        management.
    """
    extractor = _MapReduceExtractor(
        model,
        map_prompt=map_prompt,
        reduce_prompt=reduce_prompt,
        reduce=reduce,
        context_schema=context_schema,
        response_format=response_format,
    )
    return extractor.build()


__all__ = ["create_map_reduce_chain"]
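Because `reduce` also accepts an async callable, a custom reducer can be awaited as well; a hedged sketch mirroring the sync `custom_aggregator` example from the docstring (names illustrative):

```python
async def async_aggregator(state, runtime):
    # Non-LLM reduction: concatenate map results in original document order.
    ordered = sorted(state["map_results"], key=lambda r: r["indexes"][0])
    return {"result": "\n".join(str(r["result"]) for r in ordered)}


builder = create_map_reduce_chain(
    "anthropic:claude-sonnet-4-20250514", reduce=async_aggregator
)
chain = builder.compile()
# result = await chain.ainvoke({"documents": docs})
```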
libs/langchain_v1/langchain/chains/documents/stuff.py (new file, 473 lines)
@@ -0,0 +1,473 @@
"""Stuff documents chain for processing documents by putting them all in context."""

from __future__ import annotations

from typing import (
    TYPE_CHECKING,
    Any,
    Callable,
    Generic,
    Optional,
    Union,
    cast,
)

# Used not only for type checking, but is fetched at runtime by Pydantic.
from langchain_core.documents import Document as Document  # noqa: TC002
from langgraph.graph import START, StateGraph
from typing_extensions import NotRequired, TypedDict

from langchain._internal._documents import format_document_xml
from langchain._internal._prompts import aresolve_prompt, resolve_prompt
from langchain._internal._typing import ContextT
from langchain._internal._utils import RunnableCallable
from langchain.chat_models import init_chat_model

if TYPE_CHECKING:
    from langchain_core.language_models.chat_models import BaseChatModel

    # Used for type checking, but IDEs may not recognize it inside the cast.
    from langchain_core.messages import AIMessage as AIMessage
    from langchain_core.messages import MessageLikeRepresentation
    from langchain_core.runnables import RunnableConfig
    from langgraph.runtime import Runtime
    from pydantic import BaseModel


# Default system prompts
DEFAULT_INIT_PROMPT = (
    "You are a helpful assistant that summarizes text. "
    "Please provide a concise summary of the documents "
    "provided by the user."
)

DEFAULT_STRUCTURED_INIT_PROMPT = (
    "You are a helpful assistant that extracts structured information from documents. "
    "Use the provided content and optional question to generate your output, formatted "
    "according to the predefined schema."
)

DEFAULT_REFINE_PROMPT = (
    "You are a helpful assistant that refines summaries. "
    "Given an existing summary and new context, produce a refined summary "
    "that incorporates the new information while maintaining conciseness."
)

DEFAULT_STRUCTURED_REFINE_PROMPT = (
    "You are a helpful assistant refining structured information extracted "
    "from documents. "
    "You are given a previous result and new document context. "
    "Update the output to reflect the new context, staying consistent with "
    "the expected schema."
)


def _format_documents_content(documents: list[Document]) -> str:
    """Format documents into a content string.

    Args:
        documents: List of documents to format.

    Returns:
        Formatted document content string.
    """
    return "\n\n".join(format_document_xml(doc) for doc in documents)


class ExtractionState(TypedDict):
    """State for the extraction chain.

    This state tracks the extraction process where documents are processed in
    batch, with the result being refined if needed.
    """

    documents: list[Document]
    """List of documents to process."""
    result: NotRequired[Any]
    """Current result, refined with each document."""


class InputSchema(TypedDict):
    """Input schema for the extraction chain.

    Defines the expected input format when invoking the extraction chain.
    """

    documents: list[Document]
    """List of documents to process."""
    result: NotRequired[Any]
    """Existing result to refine (optional)."""


class OutputSchema(TypedDict):
    """Output schema for the extraction chain.

    Defines the format of the final result returned by the chain.
    """

    result: Any
    """Result from processing the documents."""


class ExtractionNodeUpdate(TypedDict):
    """Update returned by processing nodes."""

    result: NotRequired[Any]
    """Updated result after processing a document."""


class _Extractor(Generic[ContextT]):
    """Stuff documents chain implementation.

    This chain works by putting all the documents in the batch into the context
    window of the language model. It processes all documents together in a
    single request for extracting information or summaries. Can refine existing
    results when provided.

    Important: This chain does not attempt to control for the size of the
    context window of the LLM. Ensure your documents fit within the model's
    context limits.
    """

    def __init__(
        self,
        model: Union[BaseChatModel, str],
        *,
        prompt: Union[
            str,
            None,
            Callable[
                [ExtractionState, Runtime[ContextT]],
                list[MessageLikeRepresentation],
            ],
        ] = None,
        refine_prompt: Union[
            str,
            None,
            Callable[
                [ExtractionState, Runtime[ContextT]],
                list[MessageLikeRepresentation],
            ],
        ] = None,
        context_schema: type[ContextT] | None = None,
        response_format: Optional[type[BaseModel]] = None,
    ) -> None:
        """Initialize the Extractor.

        Args:
            model: The language model, either a chat model instance
                (e.g., `ChatAnthropic()`) or a string identifier
                (e.g., `"anthropic:claude-sonnet-4-20250514"`).
            prompt: Prompt for initial processing. Can be:
                - str: A system message string
                - None: Use the default system message
                - Callable: A function that takes (state, runtime) and returns messages
            refine_prompt: Prompt for refinement steps. Can be:
                - str: A system message string
                - None: Use the default system message
                - Callable: A function that takes (state, runtime) and returns messages
            context_schema: Optional context schema for the LangGraph runtime.
            response_format: Optional pydantic BaseModel for structured output.
        """
        self.response_format = response_format

        if isinstance(model, str):
            model = init_chat_model(model)

        self.model = (
            model.with_structured_output(response_format) if response_format else model
        )
        self.initial_prompt = prompt
        self.refine_prompt = refine_prompt
        self.context_schema = context_schema

    def _get_initial_prompt(
        self, state: ExtractionState, runtime: Runtime[ContextT]
    ) -> list[MessageLikeRepresentation]:
        """Generate the initial extraction prompt."""
        user_content = _format_documents_content(state["documents"])

        # Choose the default prompt based on structured output format
        default_prompt = (
            DEFAULT_STRUCTURED_INIT_PROMPT
            if self.response_format
            else DEFAULT_INIT_PROMPT
        )

        return resolve_prompt(
            self.initial_prompt,
            state,
            runtime,
            user_content,
            default_prompt,
        )

    async def _aget_initial_prompt(
        self, state: ExtractionState, runtime: Runtime[ContextT]
    ) -> list[MessageLikeRepresentation]:
        """Generate the initial extraction prompt (async version)."""
        user_content = _format_documents_content(state["documents"])

        # Choose the default prompt based on structured output format
        default_prompt = (
            DEFAULT_STRUCTURED_INIT_PROMPT
            if self.response_format
            else DEFAULT_INIT_PROMPT
        )

        return await aresolve_prompt(
            self.initial_prompt,
            state,
            runtime,
            user_content,
            default_prompt,
        )

    def _get_refine_prompt(
        self, state: ExtractionState, runtime: Runtime[ContextT]
    ) -> list[MessageLikeRepresentation]:
        """Generate the refinement prompt."""
        # Result should be guaranteed to exist at the refinement stage
        if "result" not in state or state["result"] == "":
            msg = (
                "Internal programming error: Result must exist when refining. "
                "This indicates that the refinement node was reached without "
                "first processing the initial result node, which violates "
                "the expected graph execution order."
            )
            raise AssertionError(msg)

        new_context = _format_documents_content(state["documents"])

        user_content = (
            f"Previous result:\n{state['result']}\n\n"
            f"New context:\n{new_context}\n\n"
            f"Please provide a refined result."
        )

        # Choose the default prompt based on structured output format
        default_prompt = (
            DEFAULT_STRUCTURED_REFINE_PROMPT
            if self.response_format
            else DEFAULT_REFINE_PROMPT
        )

        return resolve_prompt(
            self.refine_prompt,
            state,
            runtime,
            user_content,
            default_prompt,
        )

    async def _aget_refine_prompt(
        self, state: ExtractionState, runtime: Runtime[ContextT]
    ) -> list[MessageLikeRepresentation]:
        """Generate the refinement prompt (async version)."""
        # Result should be guaranteed to exist at the refinement stage
        if "result" not in state or state["result"] == "":
            msg = (
                "Internal programming error: Result must exist when refining. "
                "This indicates that the refinement node was reached without "
                "first processing the initial result node, which violates "
                "the expected graph execution order."
            )
            raise AssertionError(msg)

        new_context = _format_documents_content(state["documents"])

        user_content = (
            f"Previous result:\n{state['result']}\n\n"
            f"New context:\n{new_context}\n\n"
            f"Please provide a refined result."
        )

        # Choose the default prompt based on structured output format
        default_prompt = (
            DEFAULT_STRUCTURED_REFINE_PROMPT
            if self.response_format
            else DEFAULT_REFINE_PROMPT
        )

        return await aresolve_prompt(
            self.refine_prompt,
            state,
            runtime,
            user_content,
            default_prompt,
        )

    def create_document_processor_node(self) -> RunnableCallable:
        """Create the main document processing node.

        The node handles both initial processing and refinement of results.

        Refinement is done by providing the existing result and new context.

        If the workflow is run with a checkpointer enabled, the result will be
        persisted and available for a given thread id.
        """

        def _process_node(
            state: ExtractionState, runtime: Runtime[ContextT], config: RunnableConfig
        ) -> ExtractionNodeUpdate:
            # Handle an empty document list
            if not state["documents"]:
                return {}

            # Determine if this is initial processing or refinement
            if "result" not in state or state["result"] == "":
                # Initial processing
                prompt = self._get_initial_prompt(state, runtime)
                response = cast("AIMessage", self.model.invoke(prompt, config=config))
                result = response if self.response_format else response.text()
                return {"result": result}
            # Refinement
            prompt = self._get_refine_prompt(state, runtime)
            response = cast("AIMessage", self.model.invoke(prompt, config=config))
            result = response if self.response_format else response.text()
            return {"result": result}

        async def _aprocess_node(
            state: ExtractionState,
            runtime: Runtime[ContextT],
            config: RunnableConfig,
        ) -> ExtractionNodeUpdate:
            # Handle an empty document list
            if not state["documents"]:
                return {}

            # Determine if this is initial processing or refinement
            if "result" not in state or state["result"] == "":
                # Initial processing
                prompt = await self._aget_initial_prompt(state, runtime)
                response = cast(
                    "AIMessage", await self.model.ainvoke(prompt, config=config)
                )
                result = response if self.response_format else response.text()
                return {"result": result}
            # Refinement
            prompt = await self._aget_refine_prompt(state, runtime)
            response = cast(
                "AIMessage", await self.model.ainvoke(prompt, config=config)
            )
            result = response if self.response_format else response.text()
            return {"result": result}

        return RunnableCallable(
            _process_node,
            _aprocess_node,
            trace=False,
        )

    def build(
        self,
    ) -> StateGraph[ExtractionState, ContextT, InputSchema, OutputSchema]:
        """Build and compile the LangGraph for batch document extraction."""
        builder = StateGraph(
            ExtractionState,
            context_schema=self.context_schema,
            input_schema=InputSchema,
            output_schema=OutputSchema,
        )
        builder.add_edge(START, "process")
        builder.add_node("process", self.create_document_processor_node())
        return builder


def create_stuff_documents_chain(
    model: Union[BaseChatModel, str],
    *,
    prompt: Union[
        str,
        None,
        Callable[[ExtractionState, Runtime[ContextT]], list[MessageLikeRepresentation]],
    ] = None,
    refine_prompt: Union[
        str,
        None,
        Callable[[ExtractionState, Runtime[ContextT]], list[MessageLikeRepresentation]],
    ] = None,
    context_schema: type[ContextT] | None = None,
    response_format: Optional[type[BaseModel]] = None,
) -> StateGraph[ExtractionState, ContextT, InputSchema, OutputSchema]:
    """Create a stuff documents chain for processing documents.

    This chain works by putting all the documents in the batch into the context
    window of the language model. It processes all documents together in a
    single request for extracting information or summaries. Can refine existing
    results when provided. The default prompts are optimized for summarization
    tasks, but can be customized for other extraction tasks via the prompt
    parameters or response_format.

    Strategy:
    1. Put all documents into the context window
    2. Process all documents together in a single request
    3. If an existing result is provided, refine it with all documents at once
    4. Return the result

    Important:
        This chain does not attempt to control for the size of the context
        window of the LLM. Ensure your documents fit within the model's
        context limits.

    Example:
        ```python
        from langchain.chat_models import init_chat_model
        from langchain_core.documents import Document

        model = init_chat_model("anthropic:claude-sonnet-4-20250514")
        builder = create_stuff_documents_chain(model)
        chain = builder.compile()
        docs = [
            Document(page_content="First document content..."),
            Document(page_content="Second document content..."),
            Document(page_content="Third document content..."),
        ]
        result = chain.invoke({"documents": docs})
        print(result["result"])

        # Structured summary/extraction by passing a schema
        from pydantic import BaseModel

        class Summary(BaseModel):
            title: str
            key_points: list[str]

        builder = create_stuff_documents_chain(model, response_format=Summary)
        chain = builder.compile()
        result = chain.invoke({"documents": docs})
        print(result["result"].title)  # Access structured fields
        ```

    Args:
        model: The language model for document processing.
        prompt: Prompt for initial processing. Can be:
            - str: A system message string
            - None: Use the default system message
            - Callable: A function that takes (state, runtime) and returns messages
        refine_prompt: Prompt for refinement steps. Can be:
            - str: A system message string
            - None: Use the default system message
            - Callable: A function that takes (state, runtime) and returns messages
        context_schema: Optional context schema for the LangGraph runtime.
        response_format: Optional pydantic BaseModel for structured output.

    Returns:
        A LangGraph that can be invoked with documents to extract information.

    Note:
        This is a "stuff" documents chain that puts all documents into the
        context window and processes them together. It supports refining
        existing results. Default prompts are optimized for summarization but
        can be customized for other tasks. Important: Does not control for
        context window size.
    """
    extractor = _Extractor(
        model,
        prompt=prompt,
        refine_prompt=refine_prompt,
        context_schema=context_schema,
        response_format=response_format,
    )
    return extractor.build()


__all__ = ["create_stuff_documents_chain"]
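The optional `result` field in the input schema is what enables refinement across invocations; a hedged sketch of a second pass over newly arrived documents, continuing the docstring example (variable names illustrative):

```python
# First pass over the initial batch.
first = chain.invoke({"documents": docs})

# Second pass: refine the existing result with a new document.
new_docs = [Document(page_content="Fourth document content...")]
refined = chain.invoke({"documents": new_docs, "result": first["result"]})
print(refined["result"])
```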
libs/langchain_v1/langchain/documents/__init__.py (new file, 5 lines)
@@ -0,0 +1,5 @@
from langchain_core.documents import Document

__all__ = [
    "Document",
]
@@ -9,7 +9,7 @@ requires-python = ">=3.9, <4.0"
 dependencies = [
     "langchain-core<1.0.0,>=0.3.66",
     "langchain-text-splitters<1.0.0,>=0.3.8",
-    "langgraph>=0.5.4",
+    "langgraph>=0.6.0",
     "pydantic>=2.7.4",
 ]
@@ -123,6 +123,7 @@ ignore = [
     "UP007",   # pyupgrade: non-pep604-annotation-union
     "PLC0415", # Imports should be at the top. Not always desirable
     "PLR0913", # Too many arguments in function definition
+    "PLC0414", # Inconsistent with how type checkers expect to be notified of intentional re-exports
 ]
 unfixable = ["B028"] # People should intentionally tune the stacklevel
@@ -1667,7 +1667,7 @@ requires-dist = [
     { name = "langchain-text-splitters", editable = "../text-splitters" },
     { name = "langchain-together", marker = "extra == 'together'" },
     { name = "langchain-xai", marker = "extra == 'xai'" },
-    { name = "langgraph", specifier = ">=0.5.4" },
+    { name = "langgraph", specifier = ">=0.6.0" },
     { name = "pydantic", specifier = ">=2.7.4" },
 ]
 provides-extras = ["anthropic", "openai", "azure-ai", "google-vertexai", "google-genai", "fireworks", "ollama", "together", "mistralai", "huggingface", "groq", "aws", "deepseek", "xai", "perplexity"]
@@ -1757,7 +1757,7 @@ wheels = [

 [[package]]
 name = "langchain-core"
-version = "0.3.70"
+version = "0.3.72"
 source = { editable = "../core" }
 dependencies = [
     { name = "jsonpatch" },
@ -2133,7 +2133,7 @@ wheels = [
|
||||
|
||||
[[package]]
|
||||
name = "langgraph"
|
||||
version = "0.5.4"
|
||||
version = "0.6.1"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
dependencies = [
|
||||
{ name = "langchain-core" },
|
||||
@ -2143,9 +2143,9 @@ dependencies = [
|
||||
{ name = "pydantic" },
|
||||
{ name = "xxhash" },
|
||||
]
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/99/26/f01ae40ea26f8c723b6ec186869c80cc04de801630d99943018428b46105/langgraph-0.5.4.tar.gz", hash = "sha256:ab8f6b7b9c50fd2ae35a2efb072fbbfe79500dfc18071ac4ba6f5de5fa181931", size = 443149, upload-time = "2025-07-21T18:20:55.63Z" }
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/6e/12/4a30f766de571bfc319e70c2c0f4d050c0576e15c2249dc75ad122706a5d/langgraph-0.6.1.tar.gz", hash = "sha256:e4399ac5ad0b70f58fa28d6fe05a41b84c15959f270d6d1a86edab4e92ae148b", size = 449723, upload-time = "2025-07-29T20:45:28.438Z" }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/0d/82/15184e953234877107bad182b79c9111cb6ce6a79a97fdf36ebcaa11c0d0/langgraph-0.5.4-py3-none-any.whl", hash = "sha256:7122840225623e081be24ac30a691a24e5dac4c0361f593208f912838192d7f6", size = 143942, upload-time = "2025-07-21T18:20:54.442Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/6e/90/e37e2bc19bee81f859bfd61526d977f54ec5d8e036fa091f2cfe4f19560b/langgraph-0.6.1-py3-none-any.whl", hash = "sha256:2736027faeb6cd5c0f1ab51a5345594cfcb5eb5beeb5ac1799a58fcecf4b4eae", size = 151874, upload-time = "2025-07-29T20:45:26.998Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@ -2163,28 +2163,28 @@ wheels = [
|
||||
|
||||
[[package]]
|
||||
name = "langgraph-prebuilt"
|
||||
version = "0.5.2"
|
||||
version = "0.6.1"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
dependencies = [
|
||||
{ name = "langchain-core" },
|
||||
{ name = "langgraph-checkpoint" },
|
||||
]
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/bb/11/98134c47832fbde0caf0e06f1a104577da9215c358d7854093c1d835b272/langgraph_prebuilt-0.5.2.tar.gz", hash = "sha256:2c900a5be0d6a93ea2521e0d931697cad2b646f1fcda7aa5c39d8d7539772465", size = 117808, upload-time = "2025-06-30T19:52:48.307Z" }
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/55/b6/d4f8e800bdfdd75486595203d5c622bba5f1098e4fd4220452c75568f2be/langgraph_prebuilt-0.6.1.tar.gz", hash = "sha256:574c409113e02d3c58157877c5ea638faa80647b259027647ab88830d7ecef00", size = 125057, upload-time = "2025-07-29T20:44:48.634Z" }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/c3/64/6bc45ab9e0e1112698ebff579fe21f5606ea65cd08266995a357e312a4d2/langgraph_prebuilt-0.5.2-py3-none-any.whl", hash = "sha256:1f4cd55deca49dffc3e5127eec12fcd244fc381321002f728afa88642d5ec59d", size = 23776, upload-time = "2025-06-30T19:52:47.494Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/a6/df/cb4d73e99719b7ca0d42503d39dec49c67225779c6a1e689954a5604dbe6/langgraph_prebuilt-0.6.1-py3-none-any.whl", hash = "sha256:a3a970451371ec66509c6969505286a5d92132af7062d0b2b6dab08c2e27b50f", size = 28866, upload-time = "2025-07-29T20:44:47.72Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "langgraph-sdk"
|
||||
version = "0.1.74"
|
||||
version = "0.2.0"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
dependencies = [
|
||||
{ name = "httpx" },
|
||||
{ name = "orjson" },
|
||||
]
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/6d/f7/3807b72988f7eef5e0eb41e7e695eca50f3ed31f7cab5602db3b651c85ff/langgraph_sdk-0.1.74.tar.gz", hash = "sha256:7450e0db5b226cc2e5328ca22c5968725873630ef47c4206a30707cb25dc3ad6", size = 72190, upload-time = "2025-07-21T16:36:50.032Z" }
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/2a/3e/3dc45dc7682c9940db9edaf8773d2e157397c5bd6881f6806808afd8731e/langgraph_sdk-0.2.0.tar.gz", hash = "sha256:cd8b5f6595e5571be5cbffd04cf936978ab8f5d1005517c99715947ef871e246", size = 72510, upload-time = "2025-07-22T17:31:06.745Z" }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/1f/1a/3eacc4df8127781ee4b0b1e5cad7dbaf12510f58c42cbcb9d1e2dba2a164/langgraph_sdk-0.1.74-py3-none-any.whl", hash = "sha256:3a265c3757fe0048adad4391d10486db63ef7aa5a2cbd22da22d4503554cb890", size = 50254, upload-time = "2025-07-21T16:36:49.134Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/a5/03/a8ab0e8ea74be6058cb48bb1d85485b5c65d6ea183e3ee1aa8ca1ac73b3e/langgraph_sdk-0.2.0-py3-none-any.whl", hash = "sha256:150722264f225c4d47bbe7394676be102fdbf04c4400a0dd1bd41a70c6430cc7", size = 50569, upload-time = "2025-07-22T17:31:04.582Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
|
@ -2731,6 +2731,31 @@ class ChatOpenAI(BaseChatOpenAI):  # type: ignore[override]
    Always use ``extra_body`` for custom parameters, **not** ``model_kwargs``.
    Using ``model_kwargs`` for non-OpenAI parameters will cause API errors.

    .. dropdown:: Prompt caching optimization

        For high-volume applications with repetitive prompts, pass ``prompt_cache_key``
        per invocation to improve cache hit rates and reduce costs:

        .. code-block:: python

            llm = ChatOpenAI(model="gpt-4o-mini")

            response = llm.invoke(
                messages,
                prompt_cache_key="example-key-a",  # Routes to same machine for cache hits
            )

            customer_response = llm.invoke(messages, prompt_cache_key="example-key-b")
            support_response = llm.invoke(messages, prompt_cache_key="example-key-c")

            # Dynamic cache keys based on context
            cache_key = f"example-key-{dynamic_suffix}"
            response = llm.invoke(messages, prompt_cache_key=cache_key)

        Cache keys help ensure that requests sharing the same prompt prefix are routed
        to machines with an existing cache, reducing cost and latency on cached tokens.

    """  # noqa: E501

    max_tokens: Optional[int] = Field(default=None, alias="max_completion_tokens")
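As a follow-up to the caching docstring above, a hedged sketch for observing whether a cached prefix was actually reused. It assumes `messages` is defined and that the provider reports cache reads through `usage_metadata["input_token_details"]` (langchain-core's `UsageMetadata` defines a `cache_read` detail, populated when OpenAI returns cached token counts); not every response will include it.

```
llm = ChatOpenAI(model="gpt-4o-mini")
response = llm.invoke(messages, prompt_cache_key="example-key-a")

# Tokens served from the prompt cache, if the provider reported any.
details = (response.usage_metadata or {}).get("input_token_details", {})
print(details.get("cache_read", 0))
```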
@ -3716,6 +3741,20 @@ def _construct_responses_api_input(messages: Sequence[BaseMessage]) -> list:
    return input_


def _get_output_text(response: Response) -> str:
    """OpenAI SDK deleted ``response.output_text`` in 1.99.2."""
    if hasattr(response, "output_text"):
        return response.output_text
    texts: list[str] = []
    for output in response.output:
        if output.type == "message":
            for content in output.content:
                if content.type == "output_text":
                    texts.append(content.text)

    return "".join(texts)


def _construct_lc_result_from_responses_api(
    response: Response,
    schema: Optional[type[_BM]] = None,
@ -3830,17 +3869,18 @@ def _construct_lc_result_from_responses_api(
    # text_format=Foo,
    # stream=True,  # <-- errors
    # )
    output_text = _get_output_text(response)
    if (
        schema is not None
        and "parsed" not in additional_kwargs
        and response.output_text  # tool calls can generate empty output text
        and output_text  # tool calls can generate empty output text
        and response.text
        and (text_config := response.text.model_dump())
        and (format_ := text_config.get("format", {}))
        and (format_.get("type") == "json_schema")
    ):
        try:
            parsed_dict = json.loads(response.output_text)
            parsed_dict = json.loads(output_text)
            if schema and _is_pydantic_class(schema):
                parsed = schema(**parsed_dict)
            else:
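To make the fallback path concrete, a small sketch exercising `_get_output_text` with `SimpleNamespace` stand-ins shaped like the Responses API objects the loop above iterates over. Real callers pass an `openai` `Response` instance, and the helper would need to be imported from wherever it lives (here assumed to be `langchain_openai.chat_models.base`).

```
from types import SimpleNamespace

stub = SimpleNamespace(
    output=[
        SimpleNamespace(
            type="message",
            content=[
                SimpleNamespace(type="output_text", text="Hello, "),
                SimpleNamespace(type="output_text", text="world!"),
            ],
        )
    ],
)

# The stub has no `output_text` attribute, so the helper takes the fallback
# branch and concatenates the output_text parts of each message item.
assert _get_output_text(stub) == "Hello, world!"  # type: ignore[arg-type]
```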
@ -1112,3 +1112,46 @@ def test_tools_and_structured_output() -> None:
    assert isinstance(aggregated["raw"], AIMessage)
    assert aggregated["raw"].tool_calls
    assert aggregated["parsed"] is None


@pytest.mark.scheduled
def test_prompt_cache_key_invoke() -> None:
    """Test that prompt_cache_key works with invoke calls."""
    chat = ChatOpenAI(model="gpt-4o-mini", max_completion_tokens=20)
    messages = [HumanMessage("Say hello")]

    # Test that invoke works with the prompt_cache_key parameter
    response = chat.invoke(messages, prompt_cache_key="integration-test-v1")

    assert isinstance(response, AIMessage)
    assert isinstance(response.content, str)
    assert len(response.content) > 0

    # Test that a subsequent call with the same cache key also works
    response2 = chat.invoke(messages, prompt_cache_key="integration-test-v1")

    assert isinstance(response2, AIMessage)
    assert isinstance(response2.content, str)
    assert len(response2.content) > 0


@pytest.mark.scheduled
def test_prompt_cache_key_usage_methods_integration() -> None:
    """Integration test for prompt_cache_key usage methods."""
    messages = [HumanMessage("Say hi")]

    # Test keyword argument method
    chat = ChatOpenAI(model="gpt-4o-mini", max_completion_tokens=10)
    response = chat.invoke(messages, prompt_cache_key="integration-test-v1")
    assert isinstance(response, AIMessage)
    assert isinstance(response.content, str)

    # Test model-level via model_kwargs
    chat_model_level = ChatOpenAI(
        model="gpt-4o-mini",
        max_completion_tokens=10,
        model_kwargs={"prompt_cache_key": "integration-model-level-v1"},
    )
    response_model_level = chat_model_level.invoke(messages)
    assert isinstance(response_model_level, AIMessage)
    assert isinstance(response_model_level.content, str)
@ -0,0 +1,84 @@
"""Unit tests for the prompt_cache_key parameter."""

from langchain_core.messages import HumanMessage

from langchain_openai import ChatOpenAI


def test_prompt_cache_key_parameter_inclusion() -> None:
    """Test that the prompt_cache_key parameter is properly included in the request payload."""
    chat = ChatOpenAI(model="gpt-4o-mini", max_completion_tokens=10)
    messages = [HumanMessage("Hello")]

    payload = chat._get_request_payload(messages, prompt_cache_key="test-cache-key")
    assert "prompt_cache_key" in payload
    assert payload["prompt_cache_key"] == "test-cache-key"


def test_prompt_cache_key_parameter_exclusion() -> None:
    """Test that prompt_cache_key parameter behavior matches the OpenAI API."""
    chat = ChatOpenAI(model="gpt-4o-mini", max_completion_tokens=10)
    messages = [HumanMessage("Hello")]

    # Test with explicit None (the OpenAI API accepts None values, marked Optional)
    payload = chat._get_request_payload(messages, prompt_cache_key=None)
    assert "prompt_cache_key" in payload
    assert payload["prompt_cache_key"] is None


def test_prompt_cache_key_per_call() -> None:
    """Test that prompt_cache_key can be passed per-call with different values."""
    chat = ChatOpenAI(model="gpt-4o-mini", max_completion_tokens=10)
    messages = [HumanMessage("Hello")]

    # Test different cache keys per call
    payload1 = chat._get_request_payload(messages, prompt_cache_key="cache-v1")
    payload2 = chat._get_request_payload(messages, prompt_cache_key="cache-v2")

    assert payload1["prompt_cache_key"] == "cache-v1"
    assert payload2["prompt_cache_key"] == "cache-v2"

    # Test dynamic cache key assignment
    cache_keys = ["customer-v1", "support-v1", "feedback-v1"]

    for cache_key in cache_keys:
        payload = chat._get_request_payload(messages, prompt_cache_key=cache_key)
        assert "prompt_cache_key" in payload
        assert payload["prompt_cache_key"] == cache_key


def test_prompt_cache_key_model_kwargs() -> None:
    """Test prompt_cache_key via model_kwargs and method precedence."""
    messages = [HumanMessage("Hello world")]

    # Test model-level via model_kwargs
    chat = ChatOpenAI(
        model="gpt-4o-mini",
        max_completion_tokens=10,
        model_kwargs={"prompt_cache_key": "model-level-cache"},
    )
    payload = chat._get_request_payload(messages)
    assert "prompt_cache_key" in payload
    assert payload["prompt_cache_key"] == "model-level-cache"

    # Test that a per-call cache key overrides the model-level one
    payload_override = chat._get_request_payload(
        messages, prompt_cache_key="per-call-cache"
    )
    assert payload_override["prompt_cache_key"] == "per-call-cache"


def test_prompt_cache_key_responses_api() -> None:
    """Test that prompt_cache_key works with the Responses API."""
    chat = ChatOpenAI(
        model="gpt-4o-mini", use_responses_api=True, max_completion_tokens=10
    )

    messages = [HumanMessage("Hello")]
    payload = chat._get_request_payload(
        messages, prompt_cache_key="responses-api-cache-v1"
    )

    # prompt_cache_key should be present regardless of API type
    assert "prompt_cache_key" in payload
    assert payload["prompt_cache_key"] == "responses-api-cache-v1"
@ -995,7 +995,7 @@ wheels = [

[[package]]
name = "openai"
version = "1.98.0"
version = "1.99.2"
source = { registry = "https://pypi.org/simple" }
dependencies = [
    { name = "anyio" },
@ -1007,9 +1007,9 @@ dependencies = [
    { name = "tqdm" },
    { name = "typing-extensions" },
]
sdist = { url = "https://files.pythonhosted.org/packages/d8/9d/52eadb15c92802711d6b6cf00df3a6d0d18b588f4c5ba5ff210c6419fc03/openai-1.98.0.tar.gz", hash = "sha256:3ee0fcc50ae95267fd22bd1ad095ba5402098f3df2162592e68109999f685427", size = 496695, upload-time = "2025-07-30T12:48:03.701Z" }
sdist = { url = "https://files.pythonhosted.org/packages/de/2c/8cd1684364551237a5e6db24ce25c25ff54efcf1805b39110ec713dc2972/openai-1.99.2.tar.gz", hash = "sha256:118075b48109aa237636607b1346cf03b37cb9d74b0414cb11095850a0a22c96", size = 504752, upload-time = "2025-08-07T17:16:14.668Z" }
wheels = [
    { url = "https://files.pythonhosted.org/packages/a8/fe/f64631075b3d63a613c0d8ab761d5941631a470f6fa87eaaee1aa2b4ec0c/openai-1.98.0-py3-none-any.whl", hash = "sha256:b99b794ef92196829120e2df37647722104772d2a74d08305df9ced5f26eae34", size = 767713, upload-time = "2025-07-30T12:48:01.264Z" },
    { url = "https://files.pythonhosted.org/packages/b7/06/f3338c1b8685dc1634fa5174dc5ba2d3eecc7887c9fc539bb5da6f75ebb3/openai-1.99.2-py3-none-any.whl", hash = "sha256:110d85b8ed400e1d7b02f8db8e245bd757bcde347cb6923155f42cd66a10aa0b", size = 785594, upload-time = "2025-08-07T17:16:13.083Z" },
]

[[package]]
@ -24,8 +24,14 @@ class RetrieversIntegrationTests(BaseStandardTests):
    @property
    @abstractmethod
    def retriever_query_example(self) -> str:
        """Returns a str representing the "query" of an example retriever call."""
        ...
        """Returns a str representing the ``query`` of an example retriever call."""

    @property
    def num_results_arg_name(self) -> str:
        """Returns the name of the parameter for the number of results returned.

        Usually something like ``k`` or ``top_k``."""
        return "k"
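A hedged sketch of how a test suite might use the new property for a retriever whose constructor and `invoke` take `top_k` instead of `k`. `MyRetriever` and its import path are illustrative, not part of the diff; the `RetrieversIntegrationTests` import path is assumed from the langchain-tests package.

```
from langchain_tests.integration_tests import RetrieversIntegrationTests

from my_package.retrievers import MyRetriever  # hypothetical retriever


class TestMyRetriever(RetrieversIntegrationTests):
    @property
    def retriever_constructor(self) -> type[MyRetriever]:
        return MyRetriever

    @property
    def retriever_constructor_params(self) -> dict:
        return {"top_k": 2}  # constructor arg matching num_results_arg_name

    @property
    def retriever_query_example(self) -> str:
        return "example query"

    @property
    def num_results_arg_name(self) -> str:
        # The standard tests will now pass ``top_k`` instead of ``k`` in both
        # the constructor test and the invoke test.
        return "top_k"
```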
    @pytest.fixture
    def retriever(self) -> BaseRetriever:
@ -33,14 +39,34 @@ class RetrieversIntegrationTests(BaseStandardTests):
        return self.retriever_constructor(**self.retriever_constructor_params)

    def test_k_constructor_param(self) -> None:
        """Test that the retriever constructor accepts a k parameter, representing
        """Test the number of results constructor parameter.

        Test that the retriever constructor accepts a parameter representing
        the number of documents to return.

        By default, the parameter tested is named ``k``, but it can be overridden by
        setting the ``num_results_arg_name`` property.

        .. note::
            If the retriever doesn't support configuring the number of results returned
            via the constructor, this test can be skipped using a pytest ``xfail`` on
            the test class:

            .. code-block:: python

                @pytest.mark.xfail(
                    reason="This retriever doesn't support setting "
                    "the number of results via the constructor."
                )
                def test_k_constructor_param(self) -> None:
                    raise NotImplementedError

        .. dropdown:: Troubleshooting

            If this test fails, either the retriever constructor does not accept a k
            parameter, or the retriever does not return the correct number of documents
            (`k`) when it is set.
            If this test fails, the retriever constructor does not accept a number
            of results parameter, or the retriever does not return the correct number
            of documents (the one set in ``num_results_arg_name``) when it is
            set.

            For example, a retriever like

@ -52,29 +78,51 @@ class RetrieversIntegrationTests(BaseStandardTests):

        """
        params = {
            k: v for k, v in self.retriever_constructor_params.items() if k != "k"
            k: v
            for k, v in self.retriever_constructor_params.items()
            if k != self.num_results_arg_name
        }
        params_3 = {**params, "k": 3}
        params_3 = {**params, self.num_results_arg_name: 3}
        retriever_3 = self.retriever_constructor(**params_3)
        result_3 = retriever_3.invoke(self.retriever_query_example)
        assert len(result_3) == 3
        assert all(isinstance(doc, Document) for doc in result_3)

        params_1 = {**params, "k": 1}
        params_1 = {**params, self.num_results_arg_name: 1}
        retriever_1 = self.retriever_constructor(**params_1)
        result_1 = retriever_1.invoke(self.retriever_query_example)
        assert len(result_1) == 1
        assert all(isinstance(doc, Document) for doc in result_1)
|
||||
"""Test that the invoke method accepts a k parameter, representing the number of
|
||||
documents to return.
|
||||
"""Test the number of results parameter in ``invoke()``.
|
||||
|
||||
Test that the invoke method accepts a parameter representing
|
||||
the number of documents to return.
|
||||
|
||||
By default, the parameter is named ``, but it can be overridden by
|
||||
setting the ``num_results_arg_name`` property.
|
||||
|
||||
.. note::
|
||||
If the retriever doesn't support configuring the number of results returned
|
||||
via the invoke method, this test can be skipped using a pytest ``xfail`` on
|
||||
the test class:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
@pytest.mark.xfail(
|
||||
reason="This retriever doesn't support setting "
|
||||
"the number of results in the invoke method."
|
||||
)
|
||||
def test_invoke_with_k_kwarg(self) -> None:
|
||||
raise NotImplementedError
|
||||
|
||||
.. dropdown:: Troubleshooting
|
||||
|
||||
If this test fails, the retriever's invoke method does not accept a k
|
||||
parameter, or the retriever does not return the correct number of documents
|
||||
(`k`) when it is set.
|
||||
If this test fails, the retriever's invoke method does not accept a number
|
||||
of results parameter, or the retriever does not return the correct number
|
||||
of documents (``k`` of the one set in ``num_results_arg_name``) when it is
|
||||
set.
|
||||
|
||||
For example, a retriever like
|
||||
|
||||
@ -85,11 +133,15 @@ class RetrieversIntegrationTests(BaseStandardTests):
        should return 3 documents when invoked with a query.

        """
        result_1 = retriever.invoke(self.retriever_query_example, k=1)
        result_1 = retriever.invoke(
            self.retriever_query_example, None, **{self.num_results_arg_name: 1}
        )
        assert len(result_1) == 1
        assert all(isinstance(doc, Document) for doc in result_1)

        result_3 = retriever.invoke(self.retriever_query_example, k=3)
        result_3 = retriever.invoke(
            self.retriever_query_example, None, **{self.num_results_arg_name: 3}
        )
        assert len(result_3) == 3
        assert all(isinstance(doc, Document) for doc in result_3)
@ -100,8 +152,8 @@ class RetrieversIntegrationTests(BaseStandardTests):
        .. dropdown:: Troubleshooting

            If this test fails, the retriever's invoke method does not return a list of
            `langchain_core.document.Document` objects. Please confirm that your
            `_get_relevant_documents` method returns a list of `Document` objects.
            ``langchain_core.document.Document`` objects. Please confirm that your
            ``_get_relevant_documents`` method returns a list of ``Document`` objects.
        """
        result = retriever.invoke(self.retriever_query_example)