Compare commits


13 Commits

Author SHA1 Message Date
Sydney Runkle
fd3acabe9d run in executor and middleware signatures 2025-09-30 16:41:36 -07:00
Sydney Runkle
348075987f adding tests 2025-09-30 13:48:57 -07:00
Sydney Runkle
ea5d6f2cfa correct handling for sync / async table 2025-09-30 13:10:25 -07:00
Sydney Runkle
cd9a12cc9b conditions for finding middleware 2025-09-30 12:50:01 -07:00
Sydney Runkle
33b11630fe another pass at async 2025-09-30 12:19:14 -07:00
Eugene Yurtsev
9c97597175 chore(langchain_v1): expose middleware decorators and selected messages (#33163)
* Make it easy to improve the middleware shortcuts
* Export the messages that we're confident we'll expose
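A minimal sketch of the newly exported surface (the decorator names follow the updated middleware `__init__` in this diff; the message re-export path is an assumption):

```py
# Sketch only: exercising the names exposed by this PR.
from langchain.agents.middleware import after_model, before_model, modify_model_request
from langchain_core.messages import AIMessage, HumanMessage  # re-export source; langchain-level path assumed
```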
2025-09-30 14:14:57 -04:00
Sydney Runkle
eed0f6c289 feat(langchain): todo middleware (#33152)
Porting the [planning middleware](39c0138d0f/src/deepagents/middleware.py (L21)) over from deepagents.

Also adding the ability to configure:
* System prompt
* Tool description

```py
from langchain.agents.middleware.planning import PlanningMiddleware
from langchain.agents import create_agent
from langchain_core.messages import HumanMessage

agent = create_agent("openai:gpt-4o", middleware=[PlanningMiddleware()])

result = agent.invoke({"messages": [HumanMessage("Help me refactor my codebase")]})

print(result["todos"])  # List of todo items with status tracking
```
2025-09-30 02:23:26 +00:00
ccurme
729637a347 docs(anthropic): document support for memory tool and context management (#33149) 2025-09-29 16:38:01 -04:00
Mason Daugherty
3325196be1 fix(langchain): handle gpt-5 model name in init_chat_model (#33148)
Expand prefix matching so any `gpt-*` model maps to `openai`.
2025-09-29 16:16:17 -04:00
Mason Daugherty
f402fdcea3 fix(langchain): add context_management to Anthropic chat model init (#33150) 2025-09-29 16:13:47 -04:00
ccurme
ca9217c02d release(anthropic): 0.3.21 (#33147) 2025-09-29 19:56:28 +00:00
ccurme
f9bae40475 feat(anthropic): support memory and context management features (#33146)
https://docs.claude.com/en/docs/build-with-claude/context-editing

---------

Co-authored-by: Mason Daugherty <mason@langchain.dev>
2025-09-29 15:42:38 -04:00
ccurme
839a18e112 fix(openai): remove __future__.annotations import from test files (#33144)
Breaks schema conversion in places.
2025-09-29 16:23:32 +00:00
21 changed files with 1391 additions and 103 deletions

View File

@@ -1191,6 +1191,40 @@
"response.content"
]
},
{
"cell_type": "markdown",
"id": "74247a07-b153-444f-9c56-77659aeefc88",
"metadata": {},
"source": [
"## Context management\n",
"\n",
"Anthropic supports a context editing feature that will automatically manage the model's context window (e.g., by clearing tool results).\n",
"\n",
"See [Anthropic documentation](https://docs.claude.com/en/docs/build-with-claude/context-editing) for details and configuration options.\n",
"\n",
":::info\n",
"Requires ``langchain-anthropic>=0.3.21``\n",
":::"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "cbb79c5d-37b5-4212-b36f-f27366192cf9",
"metadata": {},
"outputs": [],
"source": [
"from langchain_anthropic import ChatAnthropic\n",
"\n",
"llm = ChatAnthropic(\n",
" model=\"claude-sonnet-4-5-20250929\",\n",
" betas=[\"context-management-2025-06-27\"],\n",
" context_management={\"edits\": [{\"type\": \"clear_tool_uses_20250919\"}]},\n",
")\n",
"llm_with_tools = llm.bind_tools([{\"type\": \"web_search_20250305\", \"name\": \"web_search\"}])\n",
"response = llm_with_tools.invoke(\"Search for recent developments in AI\")"
]
},
{
"cell_type": "markdown",
"id": "cbfec7a9-d9df-4d12-844e-d922456dd9bf",
@@ -1457,6 +1491,38 @@
"</details>"
]
},
{
"cell_type": "markdown",
"id": "29405da2-d2ef-415c-b674-6e29073cd05e",
"metadata": {},
"source": [
"### Memory tool\n",
"\n",
"Claude supports a memory tool for client-side storage and retrieval of context across conversational threads. See docs [here](https://docs.claude.com/en/docs/agents-and-tools/tool-use/memory-tool) for details.\n",
"\n",
":::info\n",
"Requires ``langchain-anthropic>=0.3.21``\n",
":::"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "bbd76eaa-041f-4fb8-8346-ca8fe0001c01",
"metadata": {},
"outputs": [],
"source": [
"from langchain_anthropic import ChatAnthropic\n",
"\n",
"llm = ChatAnthropic(\n",
" model=\"claude-sonnet-4-5-20250929\",\n",
" betas=[\"context-management-2025-06-27\"],\n",
")\n",
"llm_with_tools = llm.bind_tools([{\"type\": \"memory_20250818\", \"name\": \"memory\"}])\n",
"\n",
"response = llm_with_tools.invoke(\"What are my interests?\")"
]
},
{
"cell_type": "markdown",
"id": "040f381a-1768-479a-9a5e-aa2d7d77e0d5",

View File

@@ -118,7 +118,7 @@ def init_chat_model(
Will attempt to infer model_provider from model if not specified. The
following providers will be inferred based on these model prefixes:
- ``gpt-3...`` | ``gpt-4...`` | ``o1...`` -> ``openai``
- ``gpt-...`` | ``o1...`` | ``o3...`` -> ``openai``
- ``claude...`` -> ``anthropic``
- ``amazon...`` -> ``bedrock``
- ``gemini...`` -> ``google_vertexai``
@@ -497,7 +497,7 @@ _SUPPORTED_PROVIDERS = {
def _attempt_infer_model_provider(model_name: str) -> Optional[str]:
if any(model_name.startswith(pre) for pre in ("gpt-3", "gpt-4", "o1", "o3")):
if any(model_name.startswith(pre) for pre in ("gpt-", "o1", "o3")):
return "openai"
if model_name.startswith("claude"):
return "anthropic"

View File

@@ -270,6 +270,7 @@ def test_configurable_with_default() -> None:
"stop_sequences": None,
"anthropic_api_url": "https://api.anthropic.com",
"anthropic_proxy": None,
"context_management": None,
"anthropic_api_key": SecretStr("bar"),
"betas": None,
"default_headers": None,

View File

@@ -1,9 +1,17 @@
"""Middleware plugins for agents."""
from .human_in_the_loop import HumanInTheLoopMiddleware
from .planning import PlanningMiddleware
from .prompt_caching import AnthropicPromptCachingMiddleware
from .summarization import SummarizationMiddleware
from .types import AgentMiddleware, AgentState, ModelRequest
from .types import (
AgentMiddleware,
AgentState,
ModelRequest,
after_model,
before_model,
modify_model_request,
)
__all__ = [
"AgentMiddleware",
@@ -12,5 +20,9 @@ __all__ = [
"AnthropicPromptCachingMiddleware",
"HumanInTheLoopMiddleware",
"ModelRequest",
"PlanningMiddleware",
"SummarizationMiddleware",
"after_model",
"before_model",
"modify_model_request",
]

View File

@@ -0,0 +1,197 @@
"""Planning and task management middleware for agents."""
# ruff: noqa: E501
from __future__ import annotations
from typing import Annotated, Literal
from langchain_core.messages import ToolMessage
from langchain_core.tools import tool
from langgraph.types import Command
from typing_extensions import NotRequired, TypedDict
from langchain.agents.middleware.types import AgentMiddleware, AgentState, ModelRequest
from langchain.tools import InjectedToolCallId
class Todo(TypedDict):
"""A single todo item with content and status."""
content: str
"""The content/description of the todo item."""
status: Literal["pending", "in_progress", "completed"]
"""The current status of the todo item."""
class PlanningState(AgentState):
"""State schema for the todo middleware."""
todos: NotRequired[list[Todo]]
"""List of todo items for tracking task progress."""
WRITE_TODOS_TOOL_DESCRIPTION = """Use this tool to create and manage a structured task list for your current work session. This helps you track progress, organize complex tasks, and demonstrate thoroughness to the user.
Only use this tool if you think it will be helpful in staying organized. If the user's request is trivial and takes less than 3 steps, it is better to NOT use this tool and just do the task directly.
## When to Use This Tool
Use this tool in these scenarios:
1. Complex multi-step tasks - When a task requires 3 or more distinct steps or actions
2. Non-trivial and complex tasks - Tasks that require careful planning or multiple operations
3. User explicitly requests todo list - When the user directly asks you to use the todo list
4. User provides multiple tasks - When users provide a list of things to be done (numbered or comma-separated)
5. The plan may need future revisions or updates based on results from the first few steps
## How to Use This Tool
1. When you start working on a task - Mark it as in_progress BEFORE beginning work.
2. After completing a task - Mark it as completed and add any new follow-up tasks discovered during implementation.
3. You can also update future tasks, such as deleting them if they are no longer necessary, or adding new tasks that are necessary. Don't change previously completed tasks.
4. You can make several updates to the todo list at once. For example, when you complete a task, you can mark the next task you need to start as in_progress.
## When NOT to Use This Tool
It is important to skip using this tool when:
1. There is only a single, straightforward task
2. The task is trivial and tracking it provides no benefit
3. The task can be completed in less than 3 trivial steps
4. The task is purely conversational or informational
## Task States and Management
1. **Task States**: Use these states to track progress:
- pending: Task not yet started
- in_progress: Currently working on (you can have multiple tasks in_progress at a time if they are not related to each other and can be run in parallel)
- completed: Task finished successfully
2. **Task Management**:
- Update task status in real-time as you work
- Mark tasks complete IMMEDIATELY after finishing (don't batch completions)
- Complete current tasks before starting new ones
- Remove tasks that are no longer relevant from the list entirely
- IMPORTANT: When you write this todo list, you should mark your first task (or tasks) as in_progress immediately!
- IMPORTANT: Unless all tasks are completed, you should always have at least one task in_progress to show the user that you are working on something.
3. **Task Completion Requirements**:
- ONLY mark a task as completed when you have FULLY accomplished it
- If you encounter errors, blockers, or cannot finish, keep the task as in_progress
- When blocked, create a new task describing what needs to be resolved
- Never mark a task as completed if:
- There are unresolved issues or errors
- Work is partial or incomplete
- You encountered blockers that prevent completion
- You couldn't find necessary resources or dependencies
- Quality standards haven't been met
4. **Task Breakdown**:
- Create specific, actionable items
- Break complex tasks into smaller, manageable steps
- Use clear, descriptive task names
Being proactive with task management demonstrates attentiveness and ensures you complete all requirements successfully.
Remember: If you only need to make a few tool calls to complete a task, and it is clear what you need to do, it is better to just do the task directly and NOT call this tool at all."""
WRITE_TODOS_SYSTEM_PROMPT = """## `write_todos`
You have access to the `write_todos` tool to help you manage and plan complex objectives.
Use this tool for complex objectives to ensure that you are tracking each necessary step and giving the user visibility into your progress.
This tool is very helpful for planning complex objectives, and for breaking down these larger complex objectives into smaller steps.
It is critical that you mark todos as completed as soon as you are done with a step. Do not batch up multiple steps before marking them as completed.
For simple objectives that only require a few steps, it is better to just complete the objective directly and NOT use this tool.
Writing todos takes time and tokens; use it when it is helpful for managing complex, many-step problems, but not for simple few-step requests.
## Important To-Do List Usage Notes to Remember
- The `write_todos` tool should never be called multiple times in parallel.
- Don't be afraid to revise the To-Do list as you go. New information may reveal new tasks that need to be done, or old tasks that are irrelevant."""
@tool(description=WRITE_TODOS_TOOL_DESCRIPTION)
def write_todos(todos: list[Todo], tool_call_id: Annotated[str, InjectedToolCallId]) -> Command:
"""Create and manage a structured task list for your current work session."""
return Command(
update={
"todos": todos,
"messages": [ToolMessage(f"Updated todo list to {todos}", tool_call_id=tool_call_id)],
}
)
class PlanningMiddleware(AgentMiddleware):
"""Middleware that provides todo list management capabilities to agents.
This middleware adds a `write_todos` tool that allows agents to create and manage
structured task lists for complex multi-step operations. It's designed to help
agents track progress, organize complex tasks, and provide users with visibility
into task completion status.
The middleware automatically injects system prompts that guide the agent on when
and how to use the todo functionality effectively.
Example:
```python
from langchain.agents.middleware.planning import PlanningMiddleware
from langchain.agents import create_agent
from langchain_core.messages import HumanMessage

agent = create_agent("openai:gpt-4o", middleware=[PlanningMiddleware()])

# Agent now has access to write_todos tool and todo state tracking
result = agent.invoke({"messages": [HumanMessage("Help me refactor my codebase")]})
print(result["todos"])  # List of todo items with status tracking
```
Args:
system_prompt: Custom system prompt to guide the agent on using the todo tool.
If not provided, uses the default ``WRITE_TODOS_SYSTEM_PROMPT``.
tool_description: Custom description for the write_todos tool.
If not provided, uses the default ``WRITE_TODOS_TOOL_DESCRIPTION``.
"""
state_schema = PlanningState
def __init__(
self,
*,
system_prompt: str = WRITE_TODOS_SYSTEM_PROMPT,
tool_description: str = WRITE_TODOS_TOOL_DESCRIPTION,
) -> None:
"""Initialize the PlanningMiddleware with optional custom prompts.
Args:
system_prompt: Custom system prompt to guide the agent on using the todo tool.
tool_description: Custom description for the write_todos tool.
"""
super().__init__()
self.system_prompt = system_prompt
self.tool_description = tool_description
# Dynamically create the write_todos tool with the custom description
@tool(description=self.tool_description)
def write_todos(
todos: list[Todo], tool_call_id: Annotated[str, InjectedToolCallId]
) -> Command:
"""Create and manage a structured task list for your current work session."""
return Command(
update={
"todos": todos,
"messages": [
ToolMessage(f"Updated todo list to {todos}", tool_call_id=tool_call_id)
],
}
)
self.tools = [write_todos]
def modify_model_request( # type: ignore[override]
self,
request: ModelRequest,
state: PlanningState, # noqa: ARG002
) -> ModelRequest:
"""Update the system prompt to include the todo system prompt."""
request.system_prompt = (
request.system_prompt + "\n\n" + self.system_prompt
if request.system_prompt
else self.system_prompt
)
return request
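A short usage sketch of the configurable prompts (argument values are illustrative):

```py
from langchain.agents.middleware.planning import PlanningMiddleware

# Both arguments fall back to the module-level defaults defined above.
middleware = PlanningMiddleware(
    system_prompt="Use `write_todos` to plan any task with three or more steps.",
    tool_description="Maintain the structured todo list for this session.",
)
assert middleware.tools[0].name == "write_todos"
```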

View File

@@ -3,7 +3,7 @@
from __future__ import annotations
from dataclasses import dataclass, field
from inspect import signature
from inspect import iscoroutinefunction, signature
from typing import (
TYPE_CHECKING,
Annotated,
@@ -18,6 +18,11 @@ from typing import (
overload,
)
from langchain_core.runnables import run_in_executor
if TYPE_CHECKING:
from collections.abc import Awaitable
# needed as top level import for pydantic schema generation on AgentState
from langchain_core.messages import AnyMessage # noqa: TC002
from langgraph.channels.ephemeral_value import EphemeralValue
@@ -129,6 +134,11 @@ class AgentMiddleware(Generic[StateT, ContextT]):
def before_model(self, state: StateT, runtime: Runtime[ContextT]) -> dict[str, Any] | None:
"""Logic to run before the model is called."""
async def abefore_model(
self, state: StateT, runtime: Runtime[ContextT]
) -> dict[str, Any] | None:
"""Async logic to run before the model is called."""
def modify_model_request(
self,
request: ModelRequest,
@@ -138,14 +148,35 @@ class AgentMiddleware(Generic[StateT, ContextT]):
"""Logic to modify request kwargs before the model is called."""
return request
async def amodify_model_request(
self,
request: ModelRequest,
state: StateT,
runtime: Runtime[ContextT],
) -> ModelRequest:
"""Async logic to modify request kwargs before the model is called."""
# Try calling sync version with runtime first, fall back to without runtime
try:
return await run_in_executor(None, self.modify_model_request, request, state, runtime)
except TypeError:
# Sync version doesn't accept runtime, call without it
return await run_in_executor(None, self.modify_model_request, request, state)
def after_model(self, state: StateT, runtime: Runtime[ContextT]) -> dict[str, Any] | None:
"""Logic to run after the model is called."""
async def aafter_model(
self, state: StateT, runtime: Runtime[ContextT]
) -> dict[str, Any] | None:
"""Async logic to run after the model is called."""
class _CallableWithState(Protocol[StateT_contra]):
"""Callable with AgentState as argument."""
def __call__(self, state: StateT_contra) -> dict[str, Any] | Command | None:
def __call__(
self, state: StateT_contra
) -> dict[str, Any] | Command | None | Awaitable[dict[str, Any] | Command | None]:
"""Perform some logic with the state."""
...
@@ -155,7 +186,7 @@ class _CallableWithStateAndRuntime(Protocol[StateT_contra, ContextT]):
def __call__(
self, state: StateT_contra, runtime: Runtime[ContextT]
) -> dict[str, Any] | Command | None:
) -> dict[str, Any] | Command | None | Awaitable[dict[str, Any] | Command | None]:
"""Perform some logic with the state and runtime."""
...
@@ -163,7 +194,9 @@ class _CallableWithStateAndRuntime(Protocol[StateT_contra, ContextT]):
class _CallableWithModelRequestAndState(Protocol[StateT_contra]):
"""Callable with ModelRequest and AgentState as arguments."""
def __call__(self, request: ModelRequest, state: StateT_contra) -> ModelRequest:
def __call__(
self, request: ModelRequest, state: StateT_contra
) -> ModelRequest | Awaitable[ModelRequest]:
"""Perform some logic with the model request and state."""
...
@@ -173,7 +206,7 @@ class _CallableWithModelRequestAndStateAndRuntime(Protocol[StateT_contra, Contex
def __call__(
self, request: ModelRequest, state: StateT_contra, runtime: Runtime[ContextT]
) -> ModelRequest:
) -> ModelRequest | Awaitable[ModelRequest]:
"""Perform some logic with the model request, state, and runtime."""
...
@@ -278,14 +311,53 @@ def before_model(
"""
def decorator(func: _NodeSignature[StateT, ContextT]) -> AgentMiddleware[StateT, ContextT]:
if is_callable_with_runtime(func):
is_async = iscoroutinefunction(func)
uses_runtime = is_callable_with_runtime(func)
if is_async:
if uses_runtime:
async def async_wrapped_with_runtime(
self: AgentMiddleware[StateT, ContextT], # noqa: ARG001
state: StateT,
runtime: Runtime[ContextT],
) -> dict[str, Any] | Command | None:
return await func(state, runtime) # type: ignore[misc]
async_wrapped = async_wrapped_with_runtime
else:
async def async_wrapped_without_runtime(
self: AgentMiddleware[StateT, ContextT], # noqa: ARG001
state: StateT,
) -> dict[str, Any] | Command | None:
return await func(state) # type: ignore[call-arg,misc]
async_wrapped = async_wrapped_without_runtime # type: ignore[assignment]
middleware_name = name or cast(
"str", getattr(func, "__name__", "BeforeModelMiddleware")
)
return type(
middleware_name,
(AgentMiddleware,),
{
"state_schema": state_schema or AgentState,
"tools": tools or [],
"before_model_jump_to": jump_to or [],
"abefore_model": async_wrapped,
},
)()
if uses_runtime:
def wrapped_with_runtime(
self: AgentMiddleware[StateT, ContextT], # noqa: ARG001
state: StateT,
runtime: Runtime[ContextT],
) -> dict[str, Any] | Command | None:
return func(state, runtime)
return func(state, runtime) # type: ignore[return-value]
wrapped = wrapped_with_runtime
else:
@@ -298,7 +370,6 @@ def before_model(
wrapped = wrapped_without_runtime # type: ignore[assignment]
# Use function name as default if no name provided
middleware_name = name or cast("str", getattr(func, "__name__", "BeforeModelMiddleware"))
return type(
@@ -394,7 +465,47 @@ def modify_model_request(
def decorator(
func: _ModelRequestSignature[StateT, ContextT],
) -> AgentMiddleware[StateT, ContextT]:
if is_callable_with_runtime_and_request(func):
is_async = iscoroutinefunction(func)
uses_runtime = is_callable_with_runtime_and_request(func)
if is_async:
if uses_runtime:
async def async_wrapped_with_runtime(
self: AgentMiddleware[StateT, ContextT], # noqa: ARG001
request: ModelRequest,
state: StateT,
runtime: Runtime[ContextT],
) -> ModelRequest:
return await func(request, state, runtime) # type: ignore[misc]
async_wrapped = async_wrapped_with_runtime
else:
async def async_wrapped_without_runtime(
self: AgentMiddleware[StateT, ContextT], # noqa: ARG001
request: ModelRequest,
state: StateT,
) -> ModelRequest:
return await func(request, state) # type: ignore[call-arg,misc]
async_wrapped = async_wrapped_without_runtime # type: ignore[assignment]
middleware_name = name or cast(
"str", getattr(func, "__name__", "ModifyModelRequestMiddleware")
)
return type(
middleware_name,
(AgentMiddleware,),
{
"state_schema": state_schema or AgentState,
"tools": tools or [],
"amodify_model_request": async_wrapped,
},
)()
if uses_runtime:
def wrapped_with_runtime(
self: AgentMiddleware[StateT, ContextT], # noqa: ARG001
@@ -402,7 +513,7 @@ def modify_model_request(
state: StateT,
runtime: Runtime[ContextT],
) -> ModelRequest:
return func(request, state, runtime)
return func(request, state, runtime) # type: ignore[return-value]
wrapped = wrapped_with_runtime
else:
@@ -416,7 +527,6 @@ def modify_model_request(
wrapped = wrapped_without_runtime # type: ignore[assignment]
# Use function name as default if no name provided
middleware_name = name or cast(
"str", getattr(func, "__name__", "ModifyModelRequestMiddleware")
)
@@ -504,14 +614,51 @@ def after_model(
"""
def decorator(func: _NodeSignature[StateT, ContextT]) -> AgentMiddleware[StateT, ContextT]:
if is_callable_with_runtime(func):
is_async = iscoroutinefunction(func)
uses_runtime = is_callable_with_runtime(func)
if is_async:
if uses_runtime:
async def async_wrapped_with_runtime(
self: AgentMiddleware[StateT, ContextT], # noqa: ARG001
state: StateT,
runtime: Runtime[ContextT],
) -> dict[str, Any] | Command | None:
return await func(state, runtime) # type: ignore[misc]
async_wrapped = async_wrapped_with_runtime
else:
async def async_wrapped_without_runtime(
self: AgentMiddleware[StateT, ContextT], # noqa: ARG001
state: StateT,
) -> dict[str, Any] | Command | None:
return await func(state) # type: ignore[call-arg,misc]
async_wrapped = async_wrapped_without_runtime # type: ignore[assignment]
middleware_name = name or cast("str", getattr(func, "__name__", "AfterModelMiddleware"))
return type(
middleware_name,
(AgentMiddleware,),
{
"state_schema": state_schema or AgentState,
"tools": tools or [],
"after_model_jump_to": jump_to or [],
"aafter_model": async_wrapped,
},
)()
if uses_runtime:
def wrapped_with_runtime(
self: AgentMiddleware[StateT, ContextT], # noqa: ARG001
state: StateT,
runtime: Runtime[ContextT],
) -> dict[str, Any] | Command | None:
return func(state, runtime)
return func(state, runtime) # type: ignore[return-value]
wrapped = wrapped_with_runtime
else:
@@ -524,7 +671,6 @@ def after_model(
wrapped = wrapped_without_runtime # type: ignore[assignment]
# Use function name as default if no name provided
middleware_name = name or cast("str", getattr(func, "__name__", "AfterModelMiddleware"))
return type(
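A minimal sketch of the new async decorator path (the hook body is illustrative; the generated middleware registers it as `abefore_model`, so only the async hook exists):

```py
from typing import Any

from langchain.agents.middleware import AgentState, before_model

@before_model
async def log_message_count(state: AgentState) -> dict[str, Any] | None:
    # Runs on ainvoke/astream; no sync before_model override is generated here.
    print(f"{len(state['messages'])} messages so far")
    return None
```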

View File

@@ -2,8 +2,9 @@
import itertools
from collections.abc import Callable, Sequence
from dataclasses import dataclass
from inspect import signature
from typing import Annotated, Any, cast, get_args, get_origin, get_type_hints
from typing import Annotated, Any, Generic, cast, get_args, get_origin, get_type_hints
from langchain_core.language_models.chat_models import BaseChatModel
from langchain_core.messages import AIMessage, AnyMessage, SystemMessage, ToolMessage
@@ -38,6 +39,27 @@ from langchain.tools import ToolNode
STRUCTURED_OUTPUT_ERROR_TEMPLATE = "Error: {error}\n Please fix your mistakes."
ResponseT = TypeVar("ResponseT")
@dataclass
class MiddlewareSignature(Generic[ResponseT, ContextT]):
"""Structured metadata for a middleware's hook implementations.
Attributes:
middleware: The middleware instance.
has_sync: Whether the middleware implements a sync version of the hook.
has_async: Whether the middleware implements an async version of the hook.
sync_uses_runtime: Whether the sync hook accepts a runtime argument.
async_uses_runtime: Whether the async hook accepts a runtime argument.
"""
middleware: AgentMiddleware[AgentState[ResponseT], ContextT]
has_sync: bool
has_async: bool
sync_uses_runtime: bool
async_uses_runtime: bool
def _resolve_schema(schemas: set[type], schema_name: str, omit_flag: str | None = None) -> type:
"""Resolve schema by merging schemas and optionally respecting OmitFromSchema annotations.
@@ -130,9 +152,6 @@ def _handle_structured_output_error(
return False, ""
ResponseT = TypeVar("ResponseT")
def create_agent( # noqa: PLR0915
*,
model: str | BaseChatModel,
@@ -212,15 +231,22 @@ def create_agent( # noqa: PLR0915
"Please remove duplicate middleware instances."
)
middleware_w_before = [
m for m in middleware if m.__class__.before_model is not AgentMiddleware.before_model
m
for m in middleware
if m.__class__.before_model is not AgentMiddleware.before_model
or m.__class__.abefore_model is not AgentMiddleware.abefore_model
]
middleware_w_modify_model_request = [
m
for m in middleware
if m.__class__.modify_model_request is not AgentMiddleware.modify_model_request
or m.__class__.amodify_model_request is not AgentMiddleware.amodify_model_request
]
middleware_w_after = [
m for m in middleware if m.__class__.after_model is not AgentMiddleware.after_model
m
for m in middleware
if m.__class__.after_model is not AgentMiddleware.after_model
or m.__class__.aafter_model is not AgentMiddleware.aafter_model
]
state_schemas = {m.state_schema for m in middleware}
@@ -346,12 +372,32 @@ def create_agent( # noqa: PLR0915
)
return request.model.bind(**request.model_settings)
model_request_signatures: list[
tuple[bool, AgentMiddleware[AgentState[ResponseT], ContextT]]
] = [
("runtime" in signature(m.modify_model_request).parameters, m)
for m in middleware_w_modify_model_request
]
# Build signatures for modify_model_request middleware with async support
model_request_signatures: list[MiddlewareSignature[ResponseT, ContextT]] = []
for m in middleware_w_modify_model_request:
# Check if async version is implemented (not the default)
has_sync = m.__class__.modify_model_request is not AgentMiddleware.modify_model_request
has_async = m.__class__.amodify_model_request is not AgentMiddleware.amodify_model_request
# Check runtime usage for each implementation
sync_uses_runtime = (
"runtime" in signature(m.modify_model_request).parameters if has_sync else False
)
# If async is implemented, check its signature; otherwise default is True
# (the default async implementation always expects runtime)
async_uses_runtime = (
"runtime" in signature(m.amodify_model_request).parameters if has_async else True
)
model_request_signatures.append(
MiddlewareSignature(
middleware=m,
has_sync=has_sync,
has_async=has_async,
sync_uses_runtime=sync_uses_runtime,
async_uses_runtime=async_uses_runtime,
)
)
def model_request(state: AgentState, runtime: Runtime[ContextT]) -> dict[str, Any]:
"""Sync model request handler with sequential middleware processing."""
@@ -365,11 +411,20 @@ def create_agent( # noqa: PLR0915
)
# Apply modify_model_request middleware in sequence
for use_runtime, m in model_request_signatures:
if use_runtime:
m.modify_model_request(request, state, runtime)
else:
m.modify_model_request(request, state) # type: ignore[call-arg]
for sig in model_request_signatures:
if sig.has_sync:
if sig.sync_uses_runtime:
sig.middleware.modify_model_request(request, state, runtime)
else:
sig.middleware.modify_model_request(request, state) # type: ignore[call-arg]
elif sig.has_async:
msg = (
f"No synchronous function provided for "
f"{sig.middleware.__class__.__name__}.amodify_model_request."
"\nEither initialize with a synchronous function or invoke"
" via the async API (ainvoke, astream, etc.)"
)
raise TypeError(msg)
# Get the final model and messages
model_ = _get_bound_model(request)
@@ -393,11 +448,13 @@ def create_agent( # noqa: PLR0915
)
# Apply modify_model_request middleware in sequence
for use_runtime, m in model_request_signatures:
if use_runtime:
m.modify_model_request(request, state, runtime)
for sig in model_request_signatures:
# If async is overridden and doesn't use runtime, call without it
if sig.has_async and not sig.async_uses_runtime:
await sig.middleware.amodify_model_request(request, state) # type: ignore[call-arg]
# Otherwise call async with runtime (default implementation handles sync delegation)
else:
m.modify_model_request(request, state) # type: ignore[call-arg]
await sig.middleware.amodify_model_request(request, state, runtime)
# Get the final model and messages
model_ = _get_bound_model(request)
@@ -419,14 +476,46 @@ def create_agent( # noqa: PLR0915
# Add middleware nodes
for m in middleware:
if m.__class__.before_model is not AgentMiddleware.before_model:
if (
m.__class__.before_model is not AgentMiddleware.before_model
or m.__class__.abefore_model is not AgentMiddleware.abefore_model
):
# Use RunnableCallable to support both sync and async
# Pass None for sync if not overridden to avoid signature conflicts
sync_before = (
m.before_model
if m.__class__.before_model is not AgentMiddleware.before_model
else None
)
async_before = (
m.abefore_model
if m.__class__.abefore_model is not AgentMiddleware.abefore_model
else None
)
before_node = RunnableCallable(sync_before, async_before)
graph.add_node(
f"{m.__class__.__name__}.before_model", m.before_model, input_schema=state_schema
f"{m.__class__.__name__}.before_model", before_node, input_schema=state_schema
)
if m.__class__.after_model is not AgentMiddleware.after_model:
if (
m.__class__.after_model is not AgentMiddleware.after_model
or m.__class__.aafter_model is not AgentMiddleware.aafter_model
):
# Use RunnableCallable to support both sync and async
# Pass None for sync if not overridden to avoid signature conflicts
sync_after = (
m.after_model
if m.__class__.after_model is not AgentMiddleware.after_model
else None
)
async_after = (
m.aafter_model
if m.__class__.aafter_model is not AgentMiddleware.aafter_model
else None
)
after_node = RunnableCallable(sync_after, async_after)
graph.add_node(
f"{m.__class__.__name__}.after_model", m.after_model, input_schema=state_schema
f"{m.__class__.__name__}.after_model", after_node, input_schema=state_schema
)
# add start edge
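The sync path now fails loudly when a middleware only overrides the async request hook. A sketch of that behavior (model string illustrative; `create_agent` returns a graph that still needs compiling, as in this PR's tests):

```py
from langchain.agents.middleware.types import AgentMiddleware, ModelRequest
from langchain.agents.middleware_agent import create_agent

class AsyncOnly(AgentMiddleware):
    async def amodify_model_request(self, request, state) -> ModelRequest:
        return request

agent = create_agent(model="openai:gpt-4o", middleware=[AsyncOnly()]).compile()
# agent.invoke(...)        -> TypeError: no synchronous function provided
# await agent.ainvoke(...) -> works; the async path handles the hook directly
```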

View File

@@ -109,7 +109,7 @@ def init_chat_model(
Will attempt to infer model_provider from model if not specified. The
following providers will be inferred based on these model prefixes:
- 'gpt-3...' | 'gpt-4...' | 'o1...' -> 'openai'
- 'gpt-...' | 'o1...' | 'o3...' -> 'openai'
- 'claude...' -> 'anthropic'
- 'amazon....' -> 'bedrock'
- 'gemini...' -> 'google_vertexai'
@@ -474,7 +474,7 @@ _SUPPORTED_PROVIDERS = {
def _attempt_infer_model_provider(model_name: str) -> str | None:
if any(model_name.startswith(pre) for pre in ("gpt-3", "gpt-4", "o1", "o3")):
if any(model_name.startswith(pre) for pre in ("gpt-", "o1", "o3")):
return "openai"
if model_name.startswith("claude"):
return "anthropic"

View File

@@ -0,0 +1,17 @@
"""Message types."""
from langchain_core.messages import (
AIMessage,
AIMessageChunk,
HumanMessage,
SystemMessage,
ToolMessage,
)
__all__ = [
"AIMessage",
"AIMessageChunk",
"HumanMessage",
"SystemMessage",
"ToolMessage",
]

View File

@@ -1,14 +1,9 @@
import pytest
import warnings
from types import ModuleType
from typing import Any
from unittest.mock import patch
from types import ModuleType
from syrupy.assertion import SnapshotAssertion
import warnings
from langgraph.runtime import Runtime
from typing_extensions import Annotated
from pydantic import BaseModel, Field
import pytest
from langchain_core.language_models import BaseChatModel
from langchain_core.language_models.chat_models import BaseChatModel
from langchain_core.messages import (
@@ -18,31 +13,42 @@ from langchain_core.messages import (
ToolCall,
ToolMessage,
)
from langchain_core.tools import tool, InjectedToolCallId
from langchain_core.outputs import ChatGeneration, ChatResult
from langchain_core.tools import InjectedToolCallId, tool
from langgraph.checkpoint.base import BaseCheckpointSaver
from langgraph.checkpoint.memory import InMemorySaver
from langgraph.constants import END
from langgraph.graph.message import REMOVE_ALL_MESSAGES
from langgraph.runtime import Runtime
from langgraph.types import Command
from pydantic import BaseModel, Field
from syrupy.assertion import SnapshotAssertion
from typing_extensions import Annotated
from langchain.agents.middleware_agent import create_agent
from langchain.tools import InjectedState
from langchain.agents.middleware.human_in_the_loop import (
HumanInTheLoopMiddleware,
ActionRequest,
HumanInTheLoopMiddleware,
)
from langchain.agents.middleware.planning import (
PlanningMiddleware,
PlanningState,
WRITE_TODOS_SYSTEM_PROMPT,
write_todos,
WRITE_TODOS_TOOL_DESCRIPTION,
)
from langchain.agents.middleware.prompt_caching import AnthropicPromptCachingMiddleware
from langchain.agents.middleware.summarization import SummarizationMiddleware
from langchain.agents.middleware.types import (
AgentMiddleware,
ModelRequest,
AgentState,
ModelRequest,
OmitFromInput,
OmitFromOutput,
PrivateStateAttr,
)
from langgraph.checkpoint.base import BaseCheckpointSaver
from langgraph.checkpoint.memory import InMemorySaver
from langgraph.constants import END
from langgraph.graph.message import REMOVE_ALL_MESSAGES
from langgraph.types import Command
from langchain.agents.middleware_agent import create_agent
from langchain.agents.structured_output import ToolStrategy
from langchain.tools import InjectedState
from .messages import _AnyIdHumanMessage, _AnyIdToolMessage
from .model import FakeToolCallingModel
@@ -1105,14 +1111,9 @@ def test_summarization_middleware_summary_creation() -> None:
class MockModel(BaseChatModel):
def invoke(self, prompt):
from langchain_core.messages import AIMessage
return AIMessage(content="Generated summary")
def _generate(self, messages, **kwargs):
from langchain_core.outputs import ChatResult, ChatGeneration
from langchain_core.messages import AIMessage
return ChatResult(generations=[ChatGeneration(message=AIMessage(content="Summary"))])
@property
@@ -1136,9 +1137,6 @@ def test_summarization_middleware_summary_creation() -> None:
raise Exception("Model error")
def _generate(self, messages, **kwargs):
from langchain_core.outputs import ChatResult, ChatGeneration
from langchain_core.messages import AIMessage
return ChatResult(generations=[ChatGeneration(message=AIMessage(content="Summary"))])
@property
@@ -1155,14 +1153,9 @@ def test_summarization_middleware_full_workflow() -> None:
class MockModel(BaseChatModel):
def invoke(self, prompt):
from langchain_core.messages import AIMessage
return AIMessage(content="Generated summary")
def _generate(self, messages, **kwargs):
from langchain_core.outputs import ChatResult, ChatGeneration
from langchain_core.messages import AIMessage
return ChatResult(generations=[ChatGeneration(message=AIMessage(content="Summary"))])
@property
@@ -1423,3 +1416,481 @@ def test_jump_to_is_ephemeral() -> None:
agent = agent.compile()
result = agent.invoke({"messages": [HumanMessage("Hello")]})
assert "jump_to" not in result
# Tests for PlanningMiddleware
def test_planning_middleware_initialization() -> None:
"""Test that PlanningMiddleware initializes correctly."""
middleware = PlanningMiddleware()
assert middleware.state_schema == PlanningState
assert len(middleware.tools) == 1
assert middleware.tools[0].name == "write_todos"
@pytest.mark.parametrize(
"original_prompt,expected_prompt_prefix",
[
("Original prompt", "Original prompt\n\n## `write_todos`"),
(None, "## `write_todos`"),
],
)
def test_planning_middleware_modify_model_request(original_prompt, expected_prompt_prefix) -> None:
"""Test that modify_model_request handles system prompts correctly."""
middleware = PlanningMiddleware()
model = FakeToolCallingModel()
request = ModelRequest(
model=model,
system_prompt=original_prompt,
messages=[HumanMessage(content="Hello")],
tool_choice=None,
tools=[],
response_format=None,
model_settings={},
)
state: PlanningState = {"messages": [HumanMessage(content="Hello")]}
modified_request = middleware.modify_model_request(request, state)
assert modified_request.system_prompt.startswith(expected_prompt_prefix)
@pytest.mark.parametrize(
"todos,expected_message",
[
([], "Updated todo list to []"),
(
[{"content": "Task 1", "status": "pending"}],
"Updated todo list to [{'content': 'Task 1', 'status': 'pending'}]",
),
(
[
{"content": "Task 1", "status": "pending"},
{"content": "Task 2", "status": "in_progress"},
],
"Updated todo list to [{'content': 'Task 1', 'status': 'pending'}, {'content': 'Task 2', 'status': 'in_progress'}]",
),
(
[
{"content": "Task 1", "status": "pending"},
{"content": "Task 2", "status": "in_progress"},
{"content": "Task 3", "status": "completed"},
],
"Updated todo list to [{'content': 'Task 1', 'status': 'pending'}, {'content': 'Task 2', 'status': 'in_progress'}, {'content': 'Task 3', 'status': 'completed'}]",
),
],
)
def test_planning_middleware_write_todos_tool_execution(todos, expected_message) -> None:
"""Test that the write_todos tool executes correctly."""
tool_call = {
"args": {"todos": todos},
"name": "write_todos",
"type": "tool_call",
"id": "test_call",
}
result = write_todos.invoke(tool_call)
assert result.update["todos"] == todos
assert result.update["messages"][0].content == expected_message
@pytest.mark.parametrize(
"invalid_todos",
[
[{"content": "Task 1", "status": "invalid_status"}],
[{"status": "pending"}],
],
)
def test_planning_middleware_write_todos_tool_validation_errors(invalid_todos) -> None:
"""Test that the write_todos tool rejects invalid input."""
tool_call = {
"args": {"todos": invalid_todos},
"name": "write_todos",
"type": "tool_call",
"id": "test_call",
}
with pytest.raises(Exception):
write_todos.invoke(tool_call)
def test_planning_middleware_agent_creation_with_middleware() -> None:
"""Test that an agent can be created with the planning middleware."""
model = FakeToolCallingModel(
tool_calls=[
[
{
"args": {"todos": [{"content": "Task 1", "status": "pending"}]},
"name": "write_todos",
"type": "tool_call",
"id": "test_call",
}
],
[
{
"args": {"todos": [{"content": "Task 1", "status": "in_progress"}]},
"name": "write_todos",
"type": "tool_call",
"id": "test_call",
}
],
[
{
"args": {"todos": [{"content": "Task 1", "status": "completed"}]},
"name": "write_todos",
"type": "tool_call",
"id": "test_call",
}
],
[],
]
)
middleware = PlanningMiddleware()
agent = create_agent(model=model, middleware=[middleware])
agent = agent.compile()
result = agent.invoke({"messages": [HumanMessage("Hello")]})
assert result["todos"] == [{"content": "Task 1", "status": "completed"}]
# human message (1)
# ai message (2) - initial todo
# tool message (3)
# ai message (4) - updated todo
# tool message (5)
# ai message (6) - complete todo
# tool message (7)
# ai message (8) - no tool calls
assert len(result["messages"]) == 8
def test_planning_middleware_custom_system_prompt() -> None:
"""Test that PlanningMiddleware can be initialized with custom system prompt."""
custom_system_prompt = "Custom todo system prompt for testing"
middleware = PlanningMiddleware(system_prompt=custom_system_prompt)
model = FakeToolCallingModel()
request = ModelRequest(
model=model,
system_prompt="Original prompt",
messages=[HumanMessage(content="Hello")],
tool_choice=None,
tools=[],
response_format=None,
model_settings={},
)
state: PlanningState = {"messages": [HumanMessage(content="Hello")]}
modified_request = middleware.modify_model_request(request, state)
assert modified_request.system_prompt == f"Original prompt\n\n{custom_system_prompt}"
def test_planning_middleware_custom_tool_description() -> None:
"""Test that PlanningMiddleware can be initialized with custom tool description."""
custom_tool_description = "Custom tool description for testing"
middleware = PlanningMiddleware(tool_description=custom_tool_description)
assert len(middleware.tools) == 1
tool = middleware.tools[0]
assert tool.description == custom_tool_description
def test_planning_middleware_custom_system_prompt_and_tool_description() -> None:
"""Test that PlanningMiddleware can be initialized with both custom prompts."""
custom_system_prompt = "Custom system prompt"
custom_tool_description = "Custom tool description"
middleware = PlanningMiddleware(
system_prompt=custom_system_prompt,
tool_description=custom_tool_description,
)
# Verify system prompt
model = FakeToolCallingModel()
request = ModelRequest(
model=model,
system_prompt=None,
messages=[HumanMessage(content="Hello")],
tool_choice=None,
tools=[],
response_format=None,
model_settings={},
)
state: PlanningState = {"messages": [HumanMessage(content="Hello")]}
modified_request = middleware.modify_model_request(request, state)
assert modified_request.system_prompt == custom_system_prompt
# Verify tool description
assert len(middleware.tools) == 1
tool = middleware.tools[0]
assert tool.description == custom_tool_description
def test_planning_middleware_default_prompts() -> None:
"""Test that PlanningMiddleware uses default prompts when none provided."""
middleware = PlanningMiddleware()
# Verify default system prompt
assert middleware.system_prompt == WRITE_TODOS_SYSTEM_PROMPT
# Verify default tool description
assert middleware.tool_description == WRITE_TODOS_TOOL_DESCRIPTION
assert len(middleware.tools) == 1
tool = middleware.tools[0]
assert tool.description == WRITE_TODOS_TOOL_DESCRIPTION
def test_planning_middleware_custom_system_prompt_agent_integration() -> None:
"""Test that a custom system prompt flows through to the agent's model calls."""
middleware = PlanningMiddleware(system_prompt="call the write_todos tool")
model = FakeToolCallingModel(
tool_calls=[
[
{
"args": {"todos": [{"content": "Custom task", "status": "pending"}]},
"name": "write_todos",
"type": "tool_call",
"id": "test_call",
}
],
[],
]
)
agent = create_agent(model=model, middleware=[middleware])
agent = agent.compile()
result = agent.invoke({"messages": [HumanMessage("Hello")]})
assert result["todos"] == [{"content": "Custom task", "status": "pending"}]
# assert custom system prompt is in the first AI message
assert "call the write_todos tool" in result["messages"][1].content
# Async Middleware Tests
async def test_create_agent_async_invoke() -> None:
"""Test async invoke with async middleware hooks."""
calls = []
class AsyncMiddleware(AgentMiddleware):
async def abefore_model(self, state) -> None:
calls.append("AsyncMiddleware.abefore_model")
async def amodify_model_request(self, request, state) -> ModelRequest:
calls.append("AsyncMiddleware.amodify_model_request")
request.messages.append(HumanMessage("async middleware message"))
return request
async def aafter_model(self, state) -> None:
calls.append("AsyncMiddleware.aafter_model")
@tool
def my_tool(input: str) -> str:
"""A great tool"""
calls.append("my_tool")
return input.upper()
agent = create_agent(
model=FakeToolCallingModel(
tool_calls=[
[{"args": {"input": "yo"}, "id": "1", "name": "my_tool"}],
[],
]
),
tools=[my_tool],
system_prompt="You are a helpful assistant.",
middleware=[AsyncMiddleware()],
).compile()
result = await agent.ainvoke({"messages": [HumanMessage("hello")]})
# Should have:
# 1. Original hello message
# 2. Async middleware message (first invoke)
# 3. AI message with tool call
# 4. Tool message
# 5. Async middleware message (second invoke)
# 6. Final AI message
assert len(result["messages"]) == 6
assert result["messages"][0].content == "hello"
assert result["messages"][1].content == "async middleware message"
assert calls == [
"AsyncMiddleware.abefore_model",
"AsyncMiddleware.amodify_model_request",
"AsyncMiddleware.aafter_model",
"my_tool",
"AsyncMiddleware.abefore_model",
"AsyncMiddleware.amodify_model_request",
"AsyncMiddleware.aafter_model",
]
async def test_create_agent_async_invoke_multiple_middleware() -> None:
"""Test async invoke with multiple async middleware hooks."""
calls = []
class AsyncMiddlewareOne(AgentMiddleware):
async def abefore_model(self, state) -> None:
calls.append("AsyncMiddlewareOne.abefore_model")
async def amodify_model_request(self, request, state) -> ModelRequest:
calls.append("AsyncMiddlewareOne.amodify_model_request")
return request
async def aafter_model(self, state) -> None:
calls.append("AsyncMiddlewareOne.aafter_model")
class AsyncMiddlewareTwo(AgentMiddleware):
async def abefore_model(self, state) -> None:
calls.append("AsyncMiddlewareTwo.abefore_model")
async def amodify_model_request(self, request, state) -> ModelRequest:
calls.append("AsyncMiddlewareTwo.amodify_model_request")
return request
async def aafter_model(self, state) -> None:
calls.append("AsyncMiddlewareTwo.aafter_model")
agent = create_agent(
model=FakeToolCallingModel(),
tools=[],
system_prompt="You are a helpful assistant.",
middleware=[AsyncMiddlewareOne(), AsyncMiddlewareTwo()],
).compile()
result = await agent.ainvoke({"messages": [HumanMessage("hello")]})
assert calls == [
"AsyncMiddlewareOne.abefore_model",
"AsyncMiddlewareTwo.abefore_model",
"AsyncMiddlewareOne.amodify_model_request",
"AsyncMiddlewareTwo.amodify_model_request",
"AsyncMiddlewareTwo.aafter_model",
"AsyncMiddlewareOne.aafter_model",
]
async def test_create_agent_async_jump() -> None:
"""Test async invoke with async middleware using jump_to."""
calls = []
class AsyncMiddlewareOne(AgentMiddleware):
async def abefore_model(self, state) -> None:
calls.append("AsyncMiddlewareOne.abefore_model")
class AsyncMiddlewareTwo(AgentMiddleware):
before_model_jump_to = ["end"]
async def abefore_model(self, state) -> dict[str, Any]:
calls.append("AsyncMiddlewareTwo.abefore_model")
return {"jump_to": "end"}
agent = create_agent(
model=FakeToolCallingModel(),
tools=[],
system_prompt="You are a helpful assistant.",
middleware=[AsyncMiddlewareOne(), AsyncMiddlewareTwo()],
).compile()
result = await agent.ainvoke({"messages": []})
assert result == {"messages": []}
assert calls == ["AsyncMiddlewareOne.abefore_model", "AsyncMiddlewareTwo.abefore_model"]
async def test_create_agent_mixed_sync_async_middleware() -> None:
"""Test async invoke with mixed sync and async middleware."""
calls = []
class SyncMiddleware(AgentMiddleware):
def before_model(self, state) -> None:
calls.append("SyncMiddleware.before_model")
def modify_model_request(self, request, state) -> ModelRequest:
calls.append("SyncMiddleware.modify_model_request")
return request
def after_model(self, state) -> None:
calls.append("SyncMiddleware.after_model")
class AsyncMiddleware(AgentMiddleware):
async def abefore_model(self, state) -> None:
calls.append("AsyncMiddleware.abefore_model")
async def amodify_model_request(self, request, state) -> ModelRequest:
calls.append("AsyncMiddleware.amodify_model_request")
return request
async def aafter_model(self, state) -> None:
calls.append("AsyncMiddleware.aafter_model")
agent = create_agent(
model=FakeToolCallingModel(),
tools=[],
system_prompt="You are a helpful assistant.",
middleware=[SyncMiddleware(), AsyncMiddleware()],
).compile()
result = await agent.ainvoke({"messages": [HumanMessage("hello")]})
# In async mode, both sync and async middleware should work
assert calls == [
"SyncMiddleware.before_model",
"AsyncMiddleware.abefore_model",
"SyncMiddleware.modify_model_request",
"AsyncMiddleware.amodify_model_request",
"AsyncMiddleware.aafter_model",
"SyncMiddleware.after_model",
]
def test_create_agent_sync_invoke_with_only_async_middleware_raises_error() -> None:
"""Test that sync invoke with only async middleware raises TypeError."""
class AsyncOnlyMiddleware(AgentMiddleware):
async def amodify_model_request(self, request, state) -> ModelRequest:
return request
agent = create_agent(
model=FakeToolCallingModel(),
tools=[],
system_prompt="You are a helpful assistant.",
middleware=[AsyncOnlyMiddleware()],
).compile()
with pytest.raises(
TypeError,
match=r"No synchronous function provided for AsyncOnlyMiddleware\.amodify_model_request",
):
agent.invoke({"messages": [HumanMessage("hello")]})
def test_create_agent_sync_invoke_with_mixed_middleware() -> None:
"""Test that sync invoke works with mixed sync/async middleware when sync versions exist."""
calls = []
class MixedMiddleware(AgentMiddleware):
def before_model(self, state) -> None:
calls.append("MixedMiddleware.before_model")
async def abefore_model(self, state) -> None:
calls.append("MixedMiddleware.abefore_model")
def modify_model_request(self, request, state) -> ModelRequest:
calls.append("MixedMiddleware.modify_model_request")
return request
async def amodify_model_request(self, request, state) -> ModelRequest:
calls.append("MixedMiddleware.amodify_model_request")
return request
agent = create_agent(
model=FakeToolCallingModel(),
tools=[],
system_prompt="You are a helpful assistant.",
middleware=[MixedMiddleware()],
).compile()
result = agent.invoke({"messages": [HumanMessage("hello")]})
# In sync mode, only sync methods should be called
assert calls == [
"MixedMiddleware.before_model",
"MixedMiddleware.modify_model_request",
]

View File

@@ -1,5 +1,6 @@
"""Consolidated tests for middleware decorators: before_model, after_model, and modify_model_request."""
import pytest
from typing import Any
from typing_extensions import NotRequired
@@ -150,3 +151,161 @@ def test_decorators_use_function_names_as_default() -> None:
assert my_before_hook.__class__.__name__ == "my_before_hook"
assert my_modify_hook.__class__.__name__ == "my_modify_hook"
assert my_after_hook.__class__.__name__ == "my_after_hook"
# Async Decorator Tests
def test_async_before_model_decorator() -> None:
"""Test before_model decorator with async function."""
@before_model(state_schema=CustomState, tools=[test_tool], name="AsyncBeforeModel")
async def async_before_model(state: CustomState) -> dict[str, Any]:
return {"custom_field": "async_value"}
assert isinstance(async_before_model, AgentMiddleware)
assert async_before_model.state_schema == CustomState
assert async_before_model.tools == [test_tool]
assert async_before_model.__class__.__name__ == "AsyncBeforeModel"
def test_async_after_model_decorator() -> None:
"""Test after_model decorator with async function."""
@after_model(state_schema=CustomState, tools=[test_tool], name="AsyncAfterModel")
async def async_after_model(state: CustomState) -> dict[str, Any]:
return {"custom_field": "async_value"}
assert isinstance(async_after_model, AgentMiddleware)
assert async_after_model.state_schema == CustomState
assert async_after_model.tools == [test_tool]
assert async_after_model.__class__.__name__ == "AsyncAfterModel"
def test_async_modify_model_request_decorator() -> None:
"""Test modify_model_request decorator with async function."""
@modify_model_request(state_schema=CustomState, tools=[test_tool], name="AsyncModifyRequest")
async def async_modify_request(request: ModelRequest, state: CustomState) -> ModelRequest:
request.system_prompt = "Modified async"
return request
assert isinstance(async_modify_request, AgentMiddleware)
assert async_modify_request.state_schema == CustomState
assert async_modify_request.tools == [test_tool]
assert async_modify_request.__class__.__name__ == "AsyncModifyRequest"
def test_mixed_sync_async_decorators() -> None:
"""Test decorators with both sync and async functions."""
@before_model(name="MixedBeforeModel")
def sync_before(state: AgentState) -> None:
return None
@before_model(name="MixedBeforeModel")
async def async_before(state: AgentState) -> None:
return None
@modify_model_request(name="MixedModifyRequest")
def sync_modify(request: ModelRequest, state: AgentState) -> ModelRequest:
return request
@modify_model_request(name="MixedModifyRequest")
async def async_modify(request: ModelRequest, state: AgentState) -> ModelRequest:
return request
# Both should create valid middleware instances
assert isinstance(sync_before, AgentMiddleware)
assert isinstance(async_before, AgentMiddleware)
assert isinstance(sync_modify, AgentMiddleware)
assert isinstance(async_modify, AgentMiddleware)
@pytest.mark.asyncio
async def test_async_decorators_integration() -> None:
"""Test async decorators working together in an agent."""
call_order = []
@before_model
async def track_async_before(state: AgentState) -> None:
call_order.append("async_before")
return None
@modify_model_request
async def track_async_modify(request: ModelRequest, state: AgentState) -> ModelRequest:
call_order.append("async_modify")
return request
@after_model
async def track_async_after(state: AgentState) -> None:
call_order.append("async_after")
return None
agent = create_agent(
model=FakeToolCallingModel(),
middleware=[track_async_before, track_async_modify, track_async_after],
)
agent = agent.compile()
await agent.ainvoke({"messages": [HumanMessage("Hello")]})
assert call_order == ["async_before", "async_modify", "async_after"]
@pytest.mark.asyncio
async def test_mixed_sync_async_decorators_integration() -> None:
"""Test mixed sync/async decorators working together in an agent."""
call_order = []
@before_model
def track_sync_before(state: AgentState) -> None:
call_order.append("sync_before")
return None
@before_model
async def track_async_before(state: AgentState) -> None:
call_order.append("async_before")
return None
@modify_model_request
def track_sync_modify(request: ModelRequest, state: AgentState) -> ModelRequest:
call_order.append("sync_modify")
return request
@modify_model_request
async def track_async_modify(request: ModelRequest, state: AgentState) -> ModelRequest:
call_order.append("async_modify")
return request
@after_model
async def track_async_after(state: AgentState) -> None:
call_order.append("async_after")
return None
@after_model
def track_sync_after(state: AgentState) -> None:
call_order.append("sync_after")
return None
agent = create_agent(
model=FakeToolCallingModel(),
middleware=[
track_sync_before,
track_async_before,
track_sync_modify,
track_async_modify,
track_async_after,
track_sync_after,
],
)
agent = agent.compile()
await agent.ainvoke({"messages": [HumanMessage("Hello")]})
assert call_order == [
"sync_before",
"async_before",
"sync_modify",
"async_modify",
"sync_after",
"async_after",
]

View File

@@ -267,6 +267,7 @@ def test_configurable_with_default() -> None:
"stop_sequences": None,
"anthropic_api_url": "https://api.anthropic.com",
"anthropic_proxy": None,
"context_management": None,
"anthropic_api_key": SecretStr("bar"),
"betas": None,
"default_headers": None,

View File

@@ -12,7 +12,7 @@ from operator import itemgetter
from typing import Any, Callable, Literal, Optional, Union, cast
import anthropic
from langchain_core._api import beta, deprecated
from langchain_core._api import deprecated
from langchain_core.callbacks import (
AsyncCallbackManagerForLLMRun,
CallbackManagerForLLMRun,
@@ -91,6 +91,7 @@ def _is_builtin_tool(tool: Any) -> bool:
"web_search_",
"web_fetch_",
"code_execution_",
"memory_",
]
return any(tool_type.startswith(prefix) for prefix in _builtin_tool_prefixes)
@@ -1193,6 +1194,25 @@ class ChatAnthropic(BaseChatModel):
Total tokens: 408
Context management:
Anthropic supports a context editing feature that will automatically manage the
model's context window (e.g., by clearing tool results).
See `Anthropic documentation <https://docs.claude.com/en/docs/build-with-claude/context-editing>`__
for details and configuration options.
.. code-block:: python
from langchain_anthropic import ChatAnthropic
llm = ChatAnthropic(
model="claude-sonnet-4-5-20250929",
betas=["context-management-2025-06-27"],
context_management={"edits": [{"type": "clear_tool_uses_20250919"}]},
)
llm_with_tools = llm.bind_tools([{"type": "web_search_20250305", "name": "web_search"}])
response = llm_with_tools.invoke("Search for recent developments in AI")
Built-in tools:
See LangChain `docs <https://python.langchain.com/docs/integrations/chat/anthropic/#built-in-tools>`__
for more detail.
@@ -1306,6 +1326,19 @@ class ChatAnthropic(BaseChatModel):
'id': 'toolu_01VdNgt1YV7kGfj9LFLm6HyQ',
'type': 'tool_call'}]
.. dropdown:: Memory tool
.. code-block:: python
from langchain_anthropic import ChatAnthropic
llm = ChatAnthropic(
model="claude-sonnet-4-5-20250929",
betas=["context-management-2025-06-27"],
)
llm_with_tools = llm.bind_tools([{"type": "memory_20250818", "name": "memory"}])
response = llm_with_tools.invoke("What are my interests?")
Response metadata
.. code-block:: python
@@ -1413,6 +1446,11 @@ class ChatAnthropic(BaseChatModel):
"name": "example-mcp"}]``
"""
context_management: Optional[dict[str, Any]] = None
"""Configuration for
`context management <https://docs.claude.com/en/docs/build-with-claude/context-editing>`__.
"""
@property
def _llm_type(self) -> str:
"""Return type of chat model."""
@@ -1565,6 +1603,7 @@ class ChatAnthropic(BaseChatModel):
"top_p": self.top_p,
"stop_sequences": stop or self.stop_sequences,
"betas": self.betas,
"context_management": self.context_management,
"mcp_servers": self.mcp_servers,
"system": system,
**self.model_kwargs,
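
Since `context_management` defaults to `None`, sending it unconditionally would put a null field in every request; presumably the payload builder filters out `None` values before the API call. A minimal sketch of that filtering pattern (the `build_payload` helper is hypothetical, not the library's actual function):

from typing import Any

def build_payload(**params: Any) -> dict[str, Any]:
    # Drop None-valued entries so optional features such as
    # context_management are only sent when explicitly configured.
    return {k: v for k, v in params.items() if v is not None}

payload = build_payload(
    model="claude-sonnet-4-5-20250929",
    context_management=None,  # unset: should not appear in the request
)
assert "context_management" not in payload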
@@ -2219,7 +2258,6 @@ class ChatAnthropic(BaseChatModel):
return RunnableMap(raw=llm) | parser_with_fallback
return llm | output_parser
@beta()
def get_num_tokens_from_messages(
self,
messages: list[BaseMessage],
@@ -2234,8 +2272,8 @@ class ChatAnthropic(BaseChatModel):
messages: The message inputs to tokenize.
tools: If provided, sequence of dict, BaseModel, function, or BaseTools
to be converted to tool schemas.
kwargs: Additional keyword arguments are passed to the
:meth:`~langchain_anthropic.chat_models.ChatAnthropic.bind` method.
kwargs: Additional keyword arguments are passed to the Anthropic
``messages.count_tokens`` method.
Basic usage:
@@ -2270,7 +2308,7 @@ class ChatAnthropic(BaseChatModel):
def get_weather(location: str) -> str:
\"\"\"Get the current weather in a given location
Args:
location: The city and state, e.g. San Francisco, CA
\"\"\"
return "Sunny"
@@ -2288,15 +2326,24 @@ class ChatAnthropic(BaseChatModel):
Uses Anthropic's `token counting API <https://docs.anthropic.com/en/docs/build-with-claude/token-counting>`__ to count tokens in messages.
""" # noqa: E501
""" # noqa: D214,E501
formatted_system, formatted_messages = _format_messages(messages)
if isinstance(formatted_system, str):
kwargs["system"] = formatted_system
if tools:
kwargs["tools"] = [convert_to_anthropic_tool(tool) for tool in tools]
if self.context_management is not None:
kwargs["context_management"] = self.context_management
response = self._client.beta.messages.count_tokens(
betas=["token-counting-2024-11-01"],
if self.betas is not None:
beta_response = self._client.beta.messages.count_tokens(
betas=self.betas,
model=self.model,
messages=formatted_messages, # type: ignore[arg-type]
**kwargs,
)
return beta_response.input_tokens
response = self._client.messages.count_tokens(
model=self.model,
messages=formatted_messages, # type: ignore[arg-type]
**kwargs,
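
The hunk above routes token counting through the beta client only when `betas` is set, and otherwise uses the stable `messages.count_tokens` endpoint. A minimal usage sketch, assuming a valid `ANTHROPIC_API_KEY` in the environment:

from langchain_anthropic import ChatAnthropic
from langchain_core.messages import HumanMessage

llm = ChatAnthropic(model="claude-sonnet-4-5-20250929")
# With no `betas` configured, this counts tokens via the stable endpoint;
# setting `betas` (e.g., for context management) switches to the beta client.
n_tokens = llm.get_num_tokens_from_messages([HumanMessage("Hello, world")])
print(n_tokens)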
@@ -2409,7 +2456,7 @@ def _make_message_chunk_from_anthropic_event(
# Capture model name, but don't include usage_metadata yet
# as it will be properly reported in message_delta with complete info
if hasattr(event.message, "model"):
response_metadata = {"model_name": event.message.model}
response_metadata: dict[str, Any] = {"model_name": event.message.model}
else:
response_metadata = {}
@@ -2510,13 +2557,16 @@ def _make_message_chunk_from_anthropic_event(
# Process final usage metadata and completion info
elif event.type == "message_delta" and stream_usage:
usage_metadata = _create_usage_metadata(event.usage)
response_metadata = {
"stop_reason": event.delta.stop_reason,
"stop_sequence": event.delta.stop_sequence,
}
if context_management := getattr(event, "context_management", None):
response_metadata["context_management"] = context_management.model_dump()
message_chunk = AIMessageChunk(
content="",
usage_metadata=usage_metadata,
response_metadata={
"stop_reason": event.delta.stop_reason,
"stop_sequence": event.delta.stop_sequence,
},
response_metadata=response_metadata,
)
# Unhandled event types (e.g., `content_block_stop`, `ping` events)
# https://docs.anthropic.com/en/docs/build-with-claude/streaming#other-events
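
The `getattr` walrus above attaches `context_management` to the chunk's `response_metadata` only when the SDK event actually carries that attribute. A standalone sketch of the pattern with stand-in objects (all names here are hypothetical, not the Anthropic SDK's types):

from types import SimpleNamespace
from typing import Any

class FakeContextManagement:
    # Stand-in for an SDK pydantic object exposing model_dump().
    def model_dump(self) -> dict[str, Any]:
        return {"applied_edits": []}

event = SimpleNamespace(context_management=FakeContextManagement())

response_metadata: dict[str, Any] = {"stop_reason": "end_turn"}
# Attach the field only when the event has it; getattr returns None otherwise.
if context_management := getattr(event, "context_management", None):
    response_metadata["context_management"] = context_management.model_dump()

print(response_metadata)  # includes "context_management" only when present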

View File

@@ -7,12 +7,12 @@ authors = []
license = { text = "MIT" }
requires-python = ">=3.9.0,<4.0.0"
dependencies = [
"anthropic>=0.67.0,<1.0.0",
"anthropic>=0.69.0,<1.0.0",
"langchain-core>=0.3.76,<2.0.0",
"pydantic>=2.7.4,<3.0.0",
]
name = "langchain-anthropic"
version = "0.3.20"
version = "0.3.21"
description = "An integration package connecting Anthropic and LangChain"
readme = "README.md"

View File

@@ -1485,6 +1485,50 @@ def test_search_result_top_level() -> None:
assert any("citations" in block for block in result.content)
def test_memory_tool() -> None:
llm = ChatAnthropic(
model="claude-sonnet-4-5-20250929", # type: ignore[call-arg]
betas=["context-management-2025-06-27"],
)
llm_with_tools = llm.bind_tools([{"type": "memory_20250818", "name": "memory"}])
response = llm_with_tools.invoke("What are my interests?")
assert isinstance(response, AIMessage)
assert response.tool_calls
assert response.tool_calls[0]["name"] == "memory"
@pytest.mark.vcr
def test_context_management() -> None:
# TODO: update example to trigger action
llm = ChatAnthropic(
model="claude-sonnet-4-5-20250929", # type: ignore[call-arg]
betas=["context-management-2025-06-27"],
context_management={
"edits": [
{
"type": "clear_tool_uses_20250919",
"trigger": {"type": "input_tokens", "value": 10},
"clear_at_least": {"type": "input_tokens", "value": 5},
}
]
},
)
llm_with_tools = llm.bind_tools(
[{"type": "web_search_20250305", "name": "web_search"}]
)
input_message = {"role": "user", "content": "Search for recent developments in AI"}
response = llm_with_tools.invoke([input_message])
assert response.response_metadata.get("context_management")
# Test streaming
full: Optional[BaseMessageChunk] = None
for chunk in llm_with_tools.stream([input_message]):
assert isinstance(chunk, AIMessageChunk)
full = chunk if full is None else full + chunk
assert isinstance(full, AIMessageChunk)
assert full.response_metadata.get("context_management")
def test_async_shared_client() -> None:
llm = ChatAnthropic(model="claude-3-5-haiku-latest") # type: ignore[call-arg]
_ = asyncio.run(llm.ainvoke("Hello"))

View File

@@ -1065,9 +1065,21 @@ def test_get_num_tokens_from_messages_passes_kwargs() -> None:
    with patch.object(anthropic, "Client") as _client:
        llm.get_num_tokens_from_messages([HumanMessage("foo")], foo="bar")
    assert (
        _client.return_value.beta.messages.count_tokens.call_args.kwargs["foo"] == "bar"
    assert _client.return_value.messages.count_tokens.call_args.kwargs["foo"] == "bar"

    llm = ChatAnthropic(
        model="claude-sonnet-4-5-20250929",
        betas=["context-management-2025-06-27"],
        context_management={"edits": [{"type": "clear_tool_uses_20250919"}]},
    )
    with patch.object(anthropic, "Client") as _client:
        llm.get_num_tokens_from_messages([HumanMessage("foo")])
    call_args = _client.return_value.beta.messages.count_tokens.call_args.kwargs
    assert call_args["betas"] == ["context-management-2025-06-27"]
    assert call_args["context_management"] == {
        "edits": [{"type": "clear_tool_uses_20250919"}]
    }
def test_usage_metadata_standardization() -> None:
@@ -1217,6 +1229,22 @@ def test_cache_control_kwarg() -> None:
    ]


def test_context_management_in_payload() -> None:
    llm = ChatAnthropic(
        model="claude-sonnet-4-5-20250929",  # type: ignore[call-arg]
        betas=["context-management-2025-06-27"],
        context_management={"edits": [{"type": "clear_tool_uses_20250919"}]},
    )
    llm_with_tools = llm.bind_tools(
        [{"type": "web_search_20250305", "name": "web_search"}]
    )
    input_message = HumanMessage("Search for recent developments in AI")
    payload = llm_with_tools._get_request_payload([input_message])  # type: ignore[attr-defined]
    assert payload["context_management"] == {
        "edits": [{"type": "clear_tool_uses_20250919"}]
    }


def test_anthropic_model_params() -> None:
    llm = ChatAnthropic(model="claude-3-5-haiku-latest")

View File

@@ -1,5 +1,5 @@
version = 1
revision = 3
revision = 2
requires-python = ">=3.9.0, <4.0.0"
resolution-markers = [
"python_full_version >= '3.13' and platform_python_implementation == 'PyPy'",
@@ -21,20 +21,21 @@ wheels = [
[[package]]
name = "anthropic"
version = "0.67.0"
version = "0.69.0"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "anyio" },
{ name = "distro" },
{ name = "docstring-parser" },
{ name = "httpx" },
{ name = "jiter" },
{ name = "pydantic" },
{ name = "sniffio" },
{ name = "typing-extensions" },
]
sdist = { url = "https://files.pythonhosted.org/packages/09/08/ee91464cd821e6fca52d9a23be44815c95edd3c1cf1e844b2c5e85f0d57f/anthropic-0.67.0.tar.gz", hash = "sha256:d1531b210ea300c73423141d29bcee20fcd24ef9f426f6437c0a5d93fc98fb8e", size = 441639, upload-time = "2025-09-10T14:47:18.137Z" }
sdist = { url = "https://files.pythonhosted.org/packages/c8/9d/9ad1778b95f15c5b04e7d328c1b5f558f1e893857b7c33cd288c19c0057a/anthropic-0.69.0.tar.gz", hash = "sha256:c604d287f4d73640f40bd2c0f3265a2eb6ce034217ead0608f6b07a8bc5ae5f2", size = 480622, upload-time = "2025-09-29T16:53:45.282Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/5c/9d/9adbda372710918cc8271d089a2ceae4d977a125f90bc3c4b456bca4f281/anthropic-0.67.0-py3-none-any.whl", hash = "sha256:f80a81ec1132c514215f33d25edeeab1c4691ad5361b391ebb70d528b0605b55", size = 317126, upload-time = "2025-09-10T14:47:16.351Z" },
{ url = "https://files.pythonhosted.org/packages/9b/38/75129688de5637eb5b383e5f2b1570a5cc3aecafa4de422da8eea4b90a6c/anthropic-0.69.0-py3-none-any.whl", hash = "sha256:1f73193040f33f11e27c2cd6ec25f24fe7c3f193dc1c5cde6b7a08b18a16bcc5", size = 337265, upload-time = "2025-09-29T16:53:43.686Z" },
]
[[package]]
@@ -257,6 +258,15 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/12/b3/231ffd4ab1fc9d679809f356cebee130ac7daa00d6d6f3206dd4fd137e9e/distro-1.9.0-py3-none-any.whl", hash = "sha256:7bffd925d65168f85027d8da9af6bddab658135b840670a223589bc0c8ef02b2", size = 20277, upload-time = "2023-12-24T09:54:30.421Z" },
]
[[package]]
name = "docstring-parser"
version = "0.17.0"
source = { registry = "https://pypi.org/simple" }
sdist = { url = "https://files.pythonhosted.org/packages/b2/9d/c3b43da9515bd270df0f80548d9944e389870713cc1fe2b8fb35fe2bcefd/docstring_parser-0.17.0.tar.gz", hash = "sha256:583de4a309722b3315439bb31d64ba3eebada841f2e2cee23b99df001434c912", size = 27442, upload-time = "2025-07-21T07:35:01.868Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/55/e2/2537ebcff11c1ee1ff17d8d0b6f4db75873e3b0fb32c2d4a2ee31ecb310a/docstring_parser-0.17.0-py3-none-any.whl", hash = "sha256:cf2569abd23dce8099b300f9b4fa8191e9582dda731fd533daf54c4551658708", size = 36896, upload-time = "2025-07-21T07:35:00.684Z" },
]
[[package]]
name = "exceptiongroup"
version = "1.3.0"
@@ -455,7 +465,7 @@ wheels = [
[[package]]
name = "langchain-anthropic"
version = "0.3.20"
version = "0.3.21"
source = { editable = "." }
dependencies = [
{ name = "anthropic" },
@@ -497,7 +507,7 @@ typing = [
[package.metadata]
requires-dist = [
{ name = "anthropic", specifier = ">=0.67.0,<1.0.0" },
{ name = "anthropic", specifier = ">=0.69.0,<1.0.0" },
{ name = "langchain-core", editable = "../../core" },
{ name = "pydantic", specifier = ">=2.7.4,<3.0.0" },
]

View File

@@ -85,6 +85,7 @@ ignore = [
"RUF001",
"ERA001",
"PLR0911",
"FA100", # from __future__ import annotations breaks some schema conversion logic
# TODO
"PLR2004", # Comparison to magic number

View File

@@ -1,7 +1,5 @@
"""Test ChatOpenAI chat model."""
from __future__ import annotations
import base64
import json
from collections.abc import AsyncIterator

View File

@@ -1,7 +1,5 @@
"""Test Responses API usage."""
from __future__ import annotations
import json
import os
from typing import Annotated, Any, Literal, Optional, cast
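
Both hunks above drop `from __future__ import annotations`, matching the new `FA100` ignore: under PEP 563, annotations are stored as strings, which can break code that builds schemas by inspecting them at runtime. A standard-library illustration of the difference, independent of any LangChain or OpenAI code:

from __future__ import annotations

import typing

def add(x: int, y: int) -> int:
    return x + y

# With postponed evaluation, raw annotations are strings, not types:
print(add.__annotations__)  # {'x': 'int', 'y': 'int', 'return': 'int'}

# Resolving them requires get_type_hints, which only succeeds when every
# annotated name can be found in the function's module namespace:
print(typing.get_type_hints(add))  # {'x': <class 'int'>, ...}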