Mirror of https://github.com/hwchase17/langchain.git
Synced 2026-02-12 12:11:34 +00:00

Compare commits: 13 commits (langchain=...sr/async-p)

- fd3acabe9d
- 348075987f
- ea5d6f2cfa
- cd9a12cc9b
- 33b11630fe
- 9c97597175
- eed0f6c289
- 729637a347
- 3325196be1
- f402fdcea3
- ca9217c02d
- f9bae40475
- 839a18e112
@@ -1191,6 +1191,40 @@
     "response.content"
    ]
   },
+  {
+   "cell_type": "markdown",
+   "id": "74247a07-b153-444f-9c56-77659aeefc88",
+   "metadata": {},
+   "source": [
+    "## Context management\n",
+    "\n",
+    "Anthropic supports a context editing feature that will automatically manage the model's context window (e.g., by clearing tool results).\n",
+    "\n",
+    "See [Anthropic documentation](https://docs.claude.com/en/docs/build-with-claude/context-editing) for details and configuration options.\n",
+    "\n",
+    ":::info\n",
+    "Requires ``langchain-anthropic>=0.3.21``\n",
+    ":::"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "cbb79c5d-37b5-4212-b36f-f27366192cf9",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "from langchain_anthropic import ChatAnthropic\n",
+    "\n",
+    "llm = ChatAnthropic(\n",
+    "    model=\"claude-sonnet-4-5-20250929\",\n",
+    "    betas=[\"context-management-2025-06-27\"],\n",
+    "    context_management={\"edits\": [{\"type\": \"clear_tool_uses_20250919\"}]},\n",
+    ")\n",
+    "llm_with_tools = llm.bind_tools([{\"type\": \"web_search_20250305\", \"name\": \"web_search\"}])\n",
+    "response = llm_with_tools.invoke(\"Search for recent developments in AI\")"
+   ]
+  },
   {
    "cell_type": "markdown",
    "id": "cbfec7a9-d9df-4d12-844e-d922456dd9bf",
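For orientation, a minimal multi-turn sketch built on the cell above; it only uses the standard LangChain message-passing API, and assumes the `llm_with_tools` object from that cell:

```python
from langchain_core.messages import HumanMessage

# Assumes `llm_with_tools` from the cell above. Passing the accumulated
# history back each turn lets Anthropic's context editing clear stale tool
# results server-side once its configured thresholds are reached.
messages = [HumanMessage("Search for recent developments in AI")]
response = llm_with_tools.invoke(messages)
messages.append(response)

messages.append(HumanMessage("Summarize the three most important ones."))
followup = llm_with_tools.invoke(messages)
print(followup.content)
```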
@@ -1457,6 +1491,38 @@
     "</details>"
    ]
   },
+  {
+   "cell_type": "markdown",
+   "id": "29405da2-d2ef-415c-b674-6e29073cd05e",
+   "metadata": {},
+   "source": [
+    "### Memory tool\n",
+    "\n",
+    "Claude supports a memory tool for client-side storage and retrieval of context across conversational threads. See docs [here](https://docs.claude.com/en/docs/agents-and-tools/tool-use/memory-tool) for details.\n",
+    "\n",
+    ":::info\n",
+    "Requires ``langchain-anthropic>=0.3.21``\n",
+    ":::"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "bbd76eaa-041f-4fb8-8346-ca8fe0001c01",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "from langchain_anthropic import ChatAnthropic\n",
+    "\n",
+    "llm = ChatAnthropic(\n",
+    "    model=\"claude-sonnet-4-5-20250929\",\n",
+    "    betas=[\"context-management-2025-06-27\"],\n",
+    ")\n",
+    "llm_with_tools = llm.bind_tools([{\"type\": \"memory_20250818\", \"name\": \"memory\"}])\n",
+    "\n",
+    "response = llm_with_tools.invoke(\"What are my interests?\")"
+   ]
+  },
   {
    "cell_type": "markdown",
    "id": "040f381a-1768-479a-9a5e-aa2d7d77e0d5",
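Since the memory tool runs client-side, the model's memory tool calls have to be executed by the caller. A rough sketch under that assumption; the empty-store reply here is a placeholder, and the real command set (view, create, etc.) is defined in the Anthropic docs linked above:

```python
from langchain_core.messages import HumanMessage, ToolMessage

# Assumes `llm_with_tools` from the cell above.
messages = [HumanMessage("What are my interests?")]
response = llm_with_tools.invoke(messages)

while response.tool_calls:
    messages.append(response)
    for call in response.tool_calls:
        # A real handler would dispatch on the memory command and read/write
        # files in client-side storage; this placeholder just reports an
        # empty memory so the loop can terminate.
        messages.append(ToolMessage("(memory is empty)", tool_call_id=call["id"]))
    response = llm_with_tools.invoke(messages)

print(response.content)
```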
@@ -118,7 +118,7 @@ def init_chat_model(
         Will attempt to infer model_provider from model if not specified. The
         following providers will be inferred based on these model prefixes:

-        - ``gpt-3...`` | ``gpt-4...`` | ``o1...`` -> ``openai``
+        - ``gpt-...`` | ``o1...`` | ``o3...`` -> ``openai``
         - ``claude...`` -> ``anthropic``
         - ``amazon...`` -> ``bedrock``
         - ``gemini...`` -> ``google_vertexai``
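The practical effect of the widened prefix: any ``gpt-``-prefixed name now infers OpenAI without an explicit provider hint. A quick sketch of the public entry point (the model names are illustrative):

```python
from langchain.chat_models import init_chat_model

# "gpt-5" matches the new "gpt-" prefix (the old ("gpt-3", "gpt-4") tuple
# would have missed it), so the provider is inferred as OpenAI.
model = init_chat_model("gpt-5")

# Explicit form, for names where inference is ambiguous or unsupported:
model = init_chat_model("gpt-5", model_provider="openai")
```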
@@ -497,7 +497,7 @@ _SUPPORTED_PROVIDERS = {


 def _attempt_infer_model_provider(model_name: str) -> Optional[str]:
-    if any(model_name.startswith(pre) for pre in ("gpt-3", "gpt-4", "o1", "o3")):
+    if any(model_name.startswith(pre) for pre in ("gpt-", "o1", "o3")):
         return "openai"
     if model_name.startswith("claude"):
         return "anthropic"
@@ -270,6 +270,7 @@ def test_configurable_with_default() -> None:
         "stop_sequences": None,
         "anthropic_api_url": "https://api.anthropic.com",
         "anthropic_proxy": None,
+        "context_management": None,
         "anthropic_api_key": SecretStr("bar"),
         "betas": None,
         "default_headers": None,
@@ -1,9 +1,17 @@
 """Middleware plugins for agents."""

 from .human_in_the_loop import HumanInTheLoopMiddleware
+from .planning import PlanningMiddleware
 from .prompt_caching import AnthropicPromptCachingMiddleware
 from .summarization import SummarizationMiddleware
-from .types import AgentMiddleware, AgentState, ModelRequest
+from .types import (
+    AgentMiddleware,
+    AgentState,
+    ModelRequest,
+    after_model,
+    before_model,
+    modify_model_request,
+)

 __all__ = [
     "AgentMiddleware",
@@ -12,5 +20,9 @@ __all__ = [
     "AnthropicPromptCachingMiddleware",
     "HumanInTheLoopMiddleware",
     "ModelRequest",
+    "PlanningMiddleware",
     "SummarizationMiddleware",
+    "after_model",
+    "before_model",
+    "modify_model_request",
 ]
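The newly exported decorators wrap a plain function into an `AgentMiddleware` subclass. A minimal sketch based on the decorator signatures that appear later in this diff (the hook body and name are illustrative):

```python
from typing import Any

from langchain.agents.middleware import before_model
from langchain.agents.middleware.types import AgentState


@before_model(name="LogTurn")
def log_turn(state: AgentState) -> dict[str, Any] | None:
    # Runs before each model call; returning None leaves the state unchanged.
    print(f"model call with {len(state['messages'])} messages")
    return None
```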
libs/langchain_v1/langchain/agents/middleware/planning.py (new file, 197 lines)
@@ -0,0 +1,197 @@
"""Planning and task management middleware for agents."""
# ruff: noqa: E501

from __future__ import annotations

from typing import Annotated, Literal

from langchain_core.messages import ToolMessage
from langchain_core.tools import tool
from langgraph.types import Command
from typing_extensions import NotRequired, TypedDict

from langchain.agents.middleware.types import AgentMiddleware, AgentState, ModelRequest
from langchain.tools import InjectedToolCallId


class Todo(TypedDict):
    """A single todo item with content and status."""

    content: str
    """The content/description of the todo item."""

    status: Literal["pending", "in_progress", "completed"]
    """The current status of the todo item."""


class PlanningState(AgentState):
    """State schema for the todo middleware."""

    todos: NotRequired[list[Todo]]
    """List of todo items for tracking task progress."""


WRITE_TODOS_TOOL_DESCRIPTION = """Use this tool to create and manage a structured task list for your current work session. This helps you track progress, organize complex tasks, and demonstrate thoroughness to the user.

Only use this tool if you think it will be helpful in staying organized. If the user's request is trivial and takes less than 3 steps, it is better to NOT use this tool and just do the task directly.

## When to Use This Tool
Use this tool in these scenarios:

1. Complex multi-step tasks - When a task requires 3 or more distinct steps or actions
2. Non-trivial and complex tasks - Tasks that require careful planning or multiple operations
3. User explicitly requests todo list - When the user directly asks you to use the todo list
4. User provides multiple tasks - When users provide a list of things to be done (numbered or comma-separated)
5. The plan may need future revisions or updates based on results from the first few steps

## How to Use This Tool
1. When you start working on a task - Mark it as in_progress BEFORE beginning work.
2. After completing a task - Mark it as completed and add any new follow-up tasks discovered during implementation.
3. You can also update future tasks, such as deleting them if they are no longer necessary, or adding new tasks that are necessary. Don't change previously completed tasks.
4. You can make several updates to the todo list at once. For example, when you complete a task, you can mark the next task you need to start as in_progress.

## When NOT to Use This Tool
It is important to skip using this tool when:
1. There is only a single, straightforward task
2. The task is trivial and tracking it provides no benefit
3. The task can be completed in less than 3 trivial steps
4. The task is purely conversational or informational

## Task States and Management

1. **Task States**: Use these states to track progress:
   - pending: Task not yet started
   - in_progress: Currently working on (you can have multiple tasks in_progress at a time if they are not related to each other and can be run in parallel)
   - completed: Task finished successfully

2. **Task Management**:
   - Update task status in real-time as you work
   - Mark tasks complete IMMEDIATELY after finishing (don't batch completions)
   - Complete current tasks before starting new ones
   - Remove tasks that are no longer relevant from the list entirely
   - IMPORTANT: When you write this todo list, you should mark your first task (or tasks) as in_progress immediately!
   - IMPORTANT: Unless all tasks are completed, you should always have at least one task in_progress to show the user that you are working on something.

3. **Task Completion Requirements**:
   - ONLY mark a task as completed when you have FULLY accomplished it
   - If you encounter errors, blockers, or cannot finish, keep the task as in_progress
   - When blocked, create a new task describing what needs to be resolved
   - Never mark a task as completed if:
     - There are unresolved issues or errors
     - Work is partial or incomplete
     - You encountered blockers that prevent completion
     - You couldn't find necessary resources or dependencies
     - Quality standards haven't been met

4. **Task Breakdown**:
   - Create specific, actionable items
   - Break complex tasks into smaller, manageable steps
   - Use clear, descriptive task names

Being proactive with task management demonstrates attentiveness and ensures you complete all requirements successfully.
Remember: If you only need to make a few tool calls to complete a task, and it is clear what you need to do, it is better to just do the task directly and NOT call this tool at all."""

WRITE_TODOS_SYSTEM_PROMPT = """## `write_todos`

You have access to the `write_todos` tool to help you manage and plan complex objectives.
Use this tool for complex objectives to ensure that you are tracking each necessary step and giving the user visibility into your progress.
This tool is very helpful for planning complex objectives, and for breaking down these larger complex objectives into smaller steps.

It is critical that you mark todos as completed as soon as you are done with a step. Do not batch up multiple steps before marking them as completed.
For simple objectives that only require a few steps, it is better to just complete the objective directly and NOT use this tool.
Writing todos takes time and tokens; use it when it is helpful for managing complex many-step problems, but not for simple few-step requests.

## Important To-Do List Usage Notes to Remember
- The `write_todos` tool should never be called multiple times in parallel.
- Don't be afraid to revise the To-Do list as you go. New information may reveal new tasks that need to be done, or old tasks that are irrelevant."""


@tool(description=WRITE_TODOS_TOOL_DESCRIPTION)
def write_todos(todos: list[Todo], tool_call_id: Annotated[str, InjectedToolCallId]) -> Command:
    """Create and manage a structured task list for your current work session."""
    return Command(
        update={
            "todos": todos,
            "messages": [ToolMessage(f"Updated todo list to {todos}", tool_call_id=tool_call_id)],
        }
    )


class PlanningMiddleware(AgentMiddleware):
    """Middleware that provides todo list management capabilities to agents.

    This middleware adds a `write_todos` tool that allows agents to create and manage
    structured task lists for complex multi-step operations. It's designed to help
    agents track progress, organize complex tasks, and provide users with visibility
    into task completion status.

    The middleware automatically injects system prompts that guide the agent on when
    and how to use the todo functionality effectively.

    Example:
        ```python
        from langchain.agents.middleware.planning import PlanningMiddleware
        from langchain.agents import create_agent

        agent = create_agent("openai:gpt-4o", middleware=[PlanningMiddleware()])

        # Agent now has access to write_todos tool and todo state tracking
        result = agent.invoke({"messages": [HumanMessage("Help me refactor my codebase")]})

        print(result["todos"])  # List of todo items with status tracking
        ```

    Args:
        system_prompt: Custom system prompt to guide the agent on using the todo tool.
            If not provided, uses the default ``WRITE_TODOS_SYSTEM_PROMPT``.
        tool_description: Custom description for the write_todos tool.
            If not provided, uses the default ``WRITE_TODOS_TOOL_DESCRIPTION``.
    """

    state_schema = PlanningState

    def __init__(
        self,
        *,
        system_prompt: str = WRITE_TODOS_SYSTEM_PROMPT,
        tool_description: str = WRITE_TODOS_TOOL_DESCRIPTION,
    ) -> None:
        """Initialize the PlanningMiddleware with optional custom prompts.

        Args:
            system_prompt: Custom system prompt to guide the agent on using the todo tool.
            tool_description: Custom description for the write_todos tool.
        """
        super().__init__()
        self.system_prompt = system_prompt
        self.tool_description = tool_description

        # Dynamically create the write_todos tool with the custom description
        @tool(description=self.tool_description)
        def write_todos(
            todos: list[Todo], tool_call_id: Annotated[str, InjectedToolCallId]
        ) -> Command:
            """Create and manage a structured task list for your current work session."""
            return Command(
                update={
                    "todos": todos,
                    "messages": [
                        ToolMessage(f"Updated todo list to {todos}", tool_call_id=tool_call_id)
                    ],
                }
            )

        self.tools = [write_todos]

    def modify_model_request(  # type: ignore[override]
        self,
        request: ModelRequest,
        state: PlanningState,  # noqa: ARG002
    ) -> ModelRequest:
        """Update the system prompt to include the todo system prompt."""
        request.system_prompt = (
            request.system_prompt + "\n\n" + self.system_prompt
            if request.system_prompt
            else self.system_prompt
        )
        return request
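Both prompts are injection points. A short usage sketch of the overrides (the strings are illustrative; the defaults are `WRITE_TODOS_SYSTEM_PROMPT` and `WRITE_TODOS_TOOL_DESCRIPTION` when omitted):

```python
from langchain.agents.middleware.planning import PlanningMiddleware

middleware = PlanningMiddleware(
    system_prompt="Track every refactoring step with write_todos.",
    tool_description="Maintain the refactoring plan as a todo list.",
)

# The custom description is attached to the dynamically created tool.
assert middleware.tools[0].name == "write_todos"
assert middleware.tools[0].description.startswith("Maintain")
```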
@@ -3,7 +3,7 @@
 from __future__ import annotations

 from dataclasses import dataclass, field
-from inspect import signature
+from inspect import iscoroutinefunction, signature
 from typing import (
     TYPE_CHECKING,
     Annotated,
@@ -18,6 +18,11 @@ from typing import (
     overload,
 )

+from langchain_core.runnables import run_in_executor
+
 if TYPE_CHECKING:
+    from collections.abc import Awaitable
+
     # needed as top level import for pydantic schema generation on AgentState
     from langchain_core.messages import AnyMessage  # noqa: TC002
     from langgraph.channels.ephemeral_value import EphemeralValue
@@ -129,6 +134,11 @@ class AgentMiddleware(Generic[StateT, ContextT]):
     def before_model(self, state: StateT, runtime: Runtime[ContextT]) -> dict[str, Any] | None:
         """Logic to run before the model is called."""

+    async def abefore_model(
+        self, state: StateT, runtime: Runtime[ContextT]
+    ) -> dict[str, Any] | None:
+        """Async logic to run before the model is called."""
+
     def modify_model_request(
         self,
         request: ModelRequest,
@@ -138,14 +148,35 @@ class AgentMiddleware(Generic[StateT, ContextT]):
         """Logic to modify request kwargs before the model is called."""
         return request

+    async def amodify_model_request(
+        self,
+        request: ModelRequest,
+        state: StateT,
+        runtime: Runtime[ContextT],
+    ) -> ModelRequest:
+        """Async logic to modify request kwargs before the model is called."""
+        # Try calling sync version with runtime first, fall back to without runtime
+        try:
+            return await run_in_executor(None, self.modify_model_request, request, state, runtime)
+        except TypeError:
+            # Sync version doesn't accept runtime, call without it
+            return await run_in_executor(None, self.modify_model_request, request, state)
+
     def after_model(self, state: StateT, runtime: Runtime[ContextT]) -> dict[str, Any] | None:
         """Logic to run after the model is called."""

+    async def aafter_model(
+        self, state: StateT, runtime: Runtime[ContextT]
+    ) -> dict[str, Any] | None:
+        """Async logic to run after the model is called."""
+

 class _CallableWithState(Protocol[StateT_contra]):
     """Callable with AgentState as argument."""

-    def __call__(self, state: StateT_contra) -> dict[str, Any] | Command | None:
+    def __call__(
+        self, state: StateT_contra
+    ) -> dict[str, Any] | Command | None | Awaitable[dict[str, Any] | Command | None]:
         """Perform some logic with the state."""
         ...

@@ -155,7 +186,7 @@ class _CallableWithStateAndRuntime(Protocol[StateT_contra, ContextT]):

     def __call__(
         self, state: StateT_contra, runtime: Runtime[ContextT]
-    ) -> dict[str, Any] | Command | None:
+    ) -> dict[str, Any] | Command | None | Awaitable[dict[str, Any] | Command | None]:
         """Perform some logic with the state and runtime."""
         ...

@@ -163,7 +194,9 @@ class _CallableWithStateAndRuntime(Protocol[StateT_contra, ContextT]):
 class _CallableWithModelRequestAndState(Protocol[StateT_contra]):
     """Callable with ModelRequest and AgentState as arguments."""

-    def __call__(self, request: ModelRequest, state: StateT_contra) -> ModelRequest:
+    def __call__(
+        self, request: ModelRequest, state: StateT_contra
+    ) -> ModelRequest | Awaitable[ModelRequest]:
         """Perform some logic with the model request and state."""
         ...

@@ -173,7 +206,7 @@ class _CallableWithModelRequestAndStateAndRuntime(Protocol[StateT_contra, Contex

     def __call__(
         self, request: ModelRequest, state: StateT_contra, runtime: Runtime[ContextT]
-    ) -> ModelRequest:
+    ) -> ModelRequest | Awaitable[ModelRequest]:
         """Perform some logic with the model request, state, and runtime."""
         ...

@@ -278,14 +311,53 @@ def before_model(
     """

     def decorator(func: _NodeSignature[StateT, ContextT]) -> AgentMiddleware[StateT, ContextT]:
-        if is_callable_with_runtime(func):
+        is_async = iscoroutinefunction(func)
+        uses_runtime = is_callable_with_runtime(func)
+
+        if is_async:
+            if uses_runtime:
+
+                async def async_wrapped_with_runtime(
+                    self: AgentMiddleware[StateT, ContextT],  # noqa: ARG001
+                    state: StateT,
+                    runtime: Runtime[ContextT],
+                ) -> dict[str, Any] | Command | None:
+                    return await func(state, runtime)  # type: ignore[misc]
+
+                async_wrapped = async_wrapped_with_runtime
+            else:
+
+                async def async_wrapped_without_runtime(
+                    self: AgentMiddleware[StateT, ContextT],  # noqa: ARG001
+                    state: StateT,
+                ) -> dict[str, Any] | Command | None:
+                    return await func(state)  # type: ignore[call-arg,misc]
+
+                async_wrapped = async_wrapped_without_runtime  # type: ignore[assignment]
+
+            middleware_name = name or cast(
+                "str", getattr(func, "__name__", "BeforeModelMiddleware")
+            )
+
+            return type(
+                middleware_name,
+                (AgentMiddleware,),
+                {
+                    "state_schema": state_schema or AgentState,
+                    "tools": tools or [],
+                    "before_model_jump_to": jump_to or [],
+                    "abefore_model": async_wrapped,
+                },
+            )()
+
+        if uses_runtime:

             def wrapped_with_runtime(
                 self: AgentMiddleware[StateT, ContextT],  # noqa: ARG001
                 state: StateT,
                 runtime: Runtime[ContextT],
             ) -> dict[str, Any] | Command | None:
-                return func(state, runtime)
+                return func(state, runtime)  # type: ignore[return-value]

             wrapped = wrapped_with_runtime
         else:
@@ -298,7 +370,6 @@ def before_model(

         wrapped = wrapped_without_runtime  # type: ignore[assignment]

-        # Use function name as default if no name provided
         middleware_name = name or cast("str", getattr(func, "__name__", "BeforeModelMiddleware"))

         return type(
@@ -394,7 +465,47 @@ def modify_model_request(
     def decorator(
         func: _ModelRequestSignature[StateT, ContextT],
     ) -> AgentMiddleware[StateT, ContextT]:
-        if is_callable_with_runtime_and_request(func):
+        is_async = iscoroutinefunction(func)
+        uses_runtime = is_callable_with_runtime_and_request(func)
+
+        if is_async:
+            if uses_runtime:
+
+                async def async_wrapped_with_runtime(
+                    self: AgentMiddleware[StateT, ContextT],  # noqa: ARG001
+                    request: ModelRequest,
+                    state: StateT,
+                    runtime: Runtime[ContextT],
+                ) -> ModelRequest:
+                    return await func(request, state, runtime)  # type: ignore[misc]
+
+                async_wrapped = async_wrapped_with_runtime
+            else:
+
+                async def async_wrapped_without_runtime(
+                    self: AgentMiddleware[StateT, ContextT],  # noqa: ARG001
+                    request: ModelRequest,
+                    state: StateT,
+                ) -> ModelRequest:
+                    return await func(request, state)  # type: ignore[call-arg,misc]
+
+                async_wrapped = async_wrapped_without_runtime  # type: ignore[assignment]
+
+            middleware_name = name or cast(
+                "str", getattr(func, "__name__", "ModifyModelRequestMiddleware")
+            )
+
+            return type(
+                middleware_name,
+                (AgentMiddleware,),
+                {
+                    "state_schema": state_schema or AgentState,
+                    "tools": tools or [],
+                    "amodify_model_request": async_wrapped,
+                },
+            )()
+
+        if uses_runtime:

             def wrapped_with_runtime(
                 self: AgentMiddleware[StateT, ContextT],  # noqa: ARG001
@@ -402,7 +513,7 @@ def modify_model_request(
                 state: StateT,
                 runtime: Runtime[ContextT],
             ) -> ModelRequest:
-                return func(request, state, runtime)
+                return func(request, state, runtime)  # type: ignore[return-value]

             wrapped = wrapped_with_runtime
         else:
@@ -416,7 +527,6 @@ def modify_model_request(

         wrapped = wrapped_without_runtime  # type: ignore[assignment]

-        # Use function name as default if no name provided
         middleware_name = name or cast(
             "str", getattr(func, "__name__", "ModifyModelRequestMiddleware")
         )
@@ -504,14 +614,51 @@ def after_model(
     """

     def decorator(func: _NodeSignature[StateT, ContextT]) -> AgentMiddleware[StateT, ContextT]:
-        if is_callable_with_runtime(func):
+        is_async = iscoroutinefunction(func)
+        uses_runtime = is_callable_with_runtime(func)
+
+        if is_async:
+            if uses_runtime:
+
+                async def async_wrapped_with_runtime(
+                    self: AgentMiddleware[StateT, ContextT],  # noqa: ARG001
+                    state: StateT,
+                    runtime: Runtime[ContextT],
+                ) -> dict[str, Any] | Command | None:
+                    return await func(state, runtime)  # type: ignore[misc]
+
+                async_wrapped = async_wrapped_with_runtime
+            else:
+
+                async def async_wrapped_without_runtime(
+                    self: AgentMiddleware[StateT, ContextT],  # noqa: ARG001
+                    state: StateT,
+                ) -> dict[str, Any] | Command | None:
+                    return await func(state)  # type: ignore[call-arg,misc]
+
+                async_wrapped = async_wrapped_without_runtime  # type: ignore[assignment]
+
+            middleware_name = name or cast("str", getattr(func, "__name__", "AfterModelMiddleware"))
+
+            return type(
+                middleware_name,
+                (AgentMiddleware,),
+                {
+                    "state_schema": state_schema or AgentState,
+                    "tools": tools or [],
+                    "after_model_jump_to": jump_to or [],
+                    "aafter_model": async_wrapped,
+                },
+            )()
+
+        if uses_runtime:

             def wrapped_with_runtime(
                 self: AgentMiddleware[StateT, ContextT],  # noqa: ARG001
                 state: StateT,
                 runtime: Runtime[ContextT],
             ) -> dict[str, Any] | Command | None:
-                return func(state, runtime)
+                return func(state, runtime)  # type: ignore[return-value]

             wrapped = wrapped_with_runtime
         else:
@@ -524,7 +671,6 @@ def after_model(

         wrapped = wrapped_without_runtime  # type: ignore[assignment]

-        # Use function name as default if no name provided
         middleware_name = name or cast("str", getattr(func, "__name__", "AfterModelMiddleware"))

         return type(
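Taken together, a middleware can now implement sync hooks, async hooks, or both; when only the sync hook exists, the default `amodify_model_request` above delegates to it via `run_in_executor`. A minimal sketch mirroring the test classes later in this diff:

```python
from langchain.agents.middleware.types import AgentMiddleware, ModelRequest


class TracingMiddleware(AgentMiddleware):
    """Illustrative middleware implementing both sync and async variants."""

    def modify_model_request(self, request: ModelRequest, state) -> ModelRequest:
        # Used on the sync path (invoke / stream).
        return request

    async def amodify_model_request(self, request: ModelRequest, state) -> ModelRequest:
        # Used on the async path (ainvoke / astream); overriding this skips
        # the run_in_executor delegation to the sync hook.
        return request
```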
@@ -2,8 +2,9 @@

 import itertools
 from collections.abc import Callable, Sequence
+from dataclasses import dataclass
 from inspect import signature
-from typing import Annotated, Any, cast, get_args, get_origin, get_type_hints
+from typing import Annotated, Any, Generic, cast, get_args, get_origin, get_type_hints

 from langchain_core.language_models.chat_models import BaseChatModel
 from langchain_core.messages import AIMessage, AnyMessage, SystemMessage, ToolMessage
@@ -38,6 +39,27 @@ from langchain.tools import ToolNode

 STRUCTURED_OUTPUT_ERROR_TEMPLATE = "Error: {error}\n Please fix your mistakes."

+ResponseT = TypeVar("ResponseT")
+
+
+@dataclass
+class MiddlewareSignature(Generic[ResponseT, ContextT]):
+    """Structured metadata for a middleware's hook implementations.
+
+    Attributes:
+        middleware: The middleware instance.
+        has_sync: Whether the middleware implements a sync version of the hook.
+        has_async: Whether the middleware implements an async version of the hook.
+        sync_uses_runtime: Whether the sync hook accepts a runtime argument.
+        async_uses_runtime: Whether the async hook accepts a runtime argument.
+    """
+
+    middleware: AgentMiddleware[AgentState[ResponseT], ContextT]
+    has_sync: bool
+    has_async: bool
+    sync_uses_runtime: bool
+    async_uses_runtime: bool
+
+
 def _resolve_schema(schemas: set[type], schema_name: str, omit_flag: str | None = None) -> type:
     """Resolve schema by merging schemas and optionally respecting OmitFromSchema annotations.
@@ -130,9 +152,6 @@ def _handle_structured_output_error(
     return False, ""


-ResponseT = TypeVar("ResponseT")
-
-
 def create_agent(  # noqa: PLR0915
     *,
     model: str | BaseChatModel,
@@ -212,15 +231,22 @@ def create_agent(  # noqa: PLR0915
             "Please remove duplicate middleware instances."
         )
     middleware_w_before = [
-        m for m in middleware if m.__class__.before_model is not AgentMiddleware.before_model
+        m
+        for m in middleware
+        if m.__class__.before_model is not AgentMiddleware.before_model
+        or m.__class__.abefore_model is not AgentMiddleware.abefore_model
     ]
     middleware_w_modify_model_request = [
         m
         for m in middleware
         if m.__class__.modify_model_request is not AgentMiddleware.modify_model_request
+        or m.__class__.amodify_model_request is not AgentMiddleware.amodify_model_request
    ]
     middleware_w_after = [
-        m for m in middleware if m.__class__.after_model is not AgentMiddleware.after_model
+        m
+        for m in middleware
+        if m.__class__.after_model is not AgentMiddleware.after_model
+        or m.__class__.aafter_model is not AgentMiddleware.aafter_model
     ]

     state_schemas = {m.state_schema for m in middleware}
@@ -346,12 +372,32 @@ def create_agent(  # noqa: PLR0915
             )
         return request.model.bind(**request.model_settings)

-    model_request_signatures: list[
-        tuple[bool, AgentMiddleware[AgentState[ResponseT], ContextT]]
-    ] = [
-        ("runtime" in signature(m.modify_model_request).parameters, m)
-        for m in middleware_w_modify_model_request
-    ]
+    # Build signatures for modify_model_request middleware with async support
+    model_request_signatures: list[MiddlewareSignature[ResponseT, ContextT]] = []
+    for m in middleware_w_modify_model_request:
+        # Check if async version is implemented (not the default)
+        has_sync = m.__class__.modify_model_request is not AgentMiddleware.modify_model_request
+        has_async = m.__class__.amodify_model_request is not AgentMiddleware.amodify_model_request
+
+        # Check runtime usage for each implementation
+        sync_uses_runtime = (
+            "runtime" in signature(m.modify_model_request).parameters if has_sync else False
+        )
+        # If async is implemented, check its signature; otherwise default is True
+        # (the default async implementation always expects runtime)
+        async_uses_runtime = (
+            "runtime" in signature(m.amodify_model_request).parameters if has_async else True
+        )
+
+        model_request_signatures.append(
+            MiddlewareSignature(
+                middleware=m,
+                has_sync=has_sync,
+                has_async=has_async,
+                sync_uses_runtime=sync_uses_runtime,
+                async_uses_runtime=async_uses_runtime,
+            )
+        )

     def model_request(state: AgentState, runtime: Runtime[ContextT]) -> dict[str, Any]:
         """Sync model request handler with sequential middleware processing."""
@@ -365,11 +411,20 @@ def create_agent(  # noqa: PLR0915
             )

         # Apply modify_model_request middleware in sequence
-        for use_runtime, m in model_request_signatures:
-            if use_runtime:
-                m.modify_model_request(request, state, runtime)
-            else:
-                m.modify_model_request(request, state)  # type: ignore[call-arg]
+        for sig in model_request_signatures:
+            if sig.has_sync:
+                if sig.sync_uses_runtime:
+                    sig.middleware.modify_model_request(request, state, runtime)
+                else:
+                    sig.middleware.modify_model_request(request, state)  # type: ignore[call-arg]
+            elif sig.has_async:
+                msg = (
+                    f"No synchronous function provided for "
+                    f'{sig.middleware.__class__.__name__}.amodify_model_request".'
+                    "\nEither initialize with a synchronous function or invoke"
+                    " via the async API (ainvoke, astream, etc.)"
+                )
+                raise TypeError(msg)

         # Get the final model and messages
         model_ = _get_bound_model(request)
@@ -393,11 +448,13 @@ def create_agent(  # noqa: PLR0915
             )

         # Apply modify_model_request middleware in sequence
-        for use_runtime, m in model_request_signatures:
-            if use_runtime:
-                m.modify_model_request(request, state, runtime)
+        for sig in model_request_signatures:
+            # If async is overridden and doesn't use runtime, call without it
+            if sig.has_async and not sig.async_uses_runtime:
+                await sig.middleware.amodify_model_request(request, state)  # type: ignore[call-arg]
+            # Otherwise call async with runtime (default implementation handles sync delegation)
             else:
-                m.modify_model_request(request, state)  # type: ignore[call-arg]
+                await sig.middleware.amodify_model_request(request, state, runtime)

         # Get the final model and messages
         model_ = _get_bound_model(request)
@@ -419,14 +476,46 @@ def create_agent(  # noqa: PLR0915

     # Add middleware nodes
     for m in middleware:
-        if m.__class__.before_model is not AgentMiddleware.before_model:
+        if (
+            m.__class__.before_model is not AgentMiddleware.before_model
+            or m.__class__.abefore_model is not AgentMiddleware.abefore_model
+        ):
+            # Use RunnableCallable to support both sync and async
+            # Pass None for sync if not overridden to avoid signature conflicts
+            sync_before = (
+                m.before_model
+                if m.__class__.before_model is not AgentMiddleware.before_model
+                else None
+            )
+            async_before = (
+                m.abefore_model
+                if m.__class__.abefore_model is not AgentMiddleware.abefore_model
+                else None
+            )
+            before_node = RunnableCallable(sync_before, async_before)
             graph.add_node(
-                f"{m.__class__.__name__}.before_model", m.before_model, input_schema=state_schema
+                f"{m.__class__.__name__}.before_model", before_node, input_schema=state_schema
             )

-        if m.__class__.after_model is not AgentMiddleware.after_model:
+        if (
+            m.__class__.after_model is not AgentMiddleware.after_model
+            or m.__class__.aafter_model is not AgentMiddleware.aafter_model
+        ):
+            # Use RunnableCallable to support both sync and async
+            # Pass None for sync if not overridden to avoid signature conflicts
+            sync_after = (
+                m.after_model
+                if m.__class__.after_model is not AgentMiddleware.after_model
+                else None
+            )
+            async_after = (
+                m.aafter_model
+                if m.__class__.aafter_model is not AgentMiddleware.aafter_model
+                else None
+            )
+            after_node = RunnableCallable(sync_after, async_after)
             graph.add_node(
-                f"{m.__class__.__name__}.after_model", m.after_model, input_schema=state_schema
+                f"{m.__class__.__name__}.after_model", after_node, input_schema=state_schema
             )

     # add start edge
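One behavioral consequence of the dispatch above: a middleware that only overrides the async hook cannot be driven from the sync entry points. A sketch of what the code enforces (mirrored by a test further down in this diff):

```python
from langchain.agents.middleware.types import AgentMiddleware, ModelRequest


class AsyncOnly(AgentMiddleware):
    """Hypothetical middleware overriding only the async hook."""

    async def amodify_model_request(self, request: ModelRequest, state) -> ModelRequest:
        return request


# With this middleware, agent.invoke(...) raises TypeError ("No synchronous
# function provided for AsyncOnly.amodify_model_request"); the agent must be
# driven with `await agent.ainvoke(...)` or `agent.astream(...)` instead.
```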
@@ -109,7 +109,7 @@ def init_chat_model(
         Will attempt to infer model_provider from model if not specified. The
         following providers will be inferred based on these model prefixes:

-        - 'gpt-3...' | 'gpt-4...' | 'o1...' -> 'openai'
+        - 'gpt-...' | 'o1...' | 'o3...' -> 'openai'
         - 'claude...' -> 'anthropic'
         - 'amazon....' -> 'bedrock'
         - 'gemini...' -> 'google_vertexai'
@@ -474,7 +474,7 @@ _SUPPORTED_PROVIDERS = {


 def _attempt_infer_model_provider(model_name: str) -> str | None:
-    if any(model_name.startswith(pre) for pre in ("gpt-3", "gpt-4", "o1", "o3")):
+    if any(model_name.startswith(pre) for pre in ("gpt-", "o1", "o3")):
         return "openai"
     if model_name.startswith("claude"):
         return "anthropic"
libs/langchain_v1/langchain/messages/__init__.py (new file, 17 lines)
@@ -0,0 +1,17 @@
"""Message types."""

from langchain_core.messages import (
    AIMessage,
    AIMessageChunk,
    HumanMessage,
    SystemMessage,
    ToolMessage,
)

__all__ = [
    "AIMessage",
    "AIMessageChunk",
    "HumanMessage",
    "SystemMessage",
    "ToolMessage",
]
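The new module is a thin re-export of the core message types under a shorter import path; usage is direct:

```python
from langchain.messages import AIMessage, HumanMessage

conversation = [
    HumanMessage("ping"),
    AIMessage("pong"),
]
```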
@@ -1,14 +1,9 @@
-import pytest
+import warnings
+from types import ModuleType
 from typing import Any
 from unittest.mock import patch
-from types import ModuleType
-
-from syrupy.assertion import SnapshotAssertion
-
-import warnings
-from langgraph.runtime import Runtime
-from typing_extensions import Annotated
-from pydantic import BaseModel, Field
+import pytest
-from langchain_core.language_models import BaseChatModel
+from langchain_core.language_models.chat_models import BaseChatModel
 from langchain_core.messages import (
@@ -18,31 +13,42 @@ from langchain_core.messages import (
     ToolCall,
     ToolMessage,
 )
-from langchain_core.tools import tool, InjectedToolCallId
+from langchain_core.outputs import ChatGeneration, ChatResult
+from langchain_core.tools import InjectedToolCallId, tool
+from langgraph.checkpoint.base import BaseCheckpointSaver
+from langgraph.checkpoint.memory import InMemorySaver
+from langgraph.constants import END
+from langgraph.graph.message import REMOVE_ALL_MESSAGES
+from langgraph.runtime import Runtime
+from langgraph.types import Command
+from pydantic import BaseModel, Field
+from syrupy.assertion import SnapshotAssertion
+from typing_extensions import Annotated

-from langchain.agents.middleware_agent import create_agent
-from langchain.tools import InjectedState
 from langchain.agents.middleware.human_in_the_loop import (
-    HumanInTheLoopMiddleware,
+    ActionRequest,
     HumanInTheLoopMiddleware,
 )
+from langchain.agents.middleware.planning import (
+    PlanningMiddleware,
+    PlanningState,
+    WRITE_TODOS_SYSTEM_PROMPT,
+    write_todos,
+    WRITE_TODOS_TOOL_DESCRIPTION,
+)
 from langchain.agents.middleware.prompt_caching import AnthropicPromptCachingMiddleware
 from langchain.agents.middleware.summarization import SummarizationMiddleware
 from langchain.agents.middleware.types import (
     AgentMiddleware,
-    ModelRequest,
+    AgentState,
     ModelRequest,
     OmitFromInput,
     OmitFromOutput,
     PrivateStateAttr,
 )
-
-from langgraph.checkpoint.base import BaseCheckpointSaver
-from langgraph.checkpoint.memory import InMemorySaver
-from langgraph.constants import END
-from langgraph.graph.message import REMOVE_ALL_MESSAGES
-from langgraph.types import Command
+from langchain.agents.middleware_agent import create_agent
 from langchain.agents.structured_output import ToolStrategy
+from langchain.tools import InjectedState

 from .messages import _AnyIdHumanMessage, _AnyIdToolMessage
 from .model import FakeToolCallingModel
@@ -1105,14 +1111,9 @@ def test_summarization_middleware_summary_creation() -> None:
|
||||
|
||||
class MockModel(BaseChatModel):
|
||||
def invoke(self, prompt):
|
||||
from langchain_core.messages import AIMessage
|
||||
|
||||
return AIMessage(content="Generated summary")
|
||||
|
||||
def _generate(self, messages, **kwargs):
|
||||
from langchain_core.outputs import ChatResult, ChatGeneration
|
||||
from langchain_core.messages import AIMessage
|
||||
|
||||
return ChatResult(generations=[ChatGeneration(message=AIMessage(content="Summary"))])
|
||||
|
||||
@property
|
||||
@@ -1136,9 +1137,6 @@ def test_summarization_middleware_summary_creation() -> None:
|
||||
raise Exception("Model error")
|
||||
|
||||
def _generate(self, messages, **kwargs):
|
||||
from langchain_core.outputs import ChatResult, ChatGeneration
|
||||
from langchain_core.messages import AIMessage
|
||||
|
||||
return ChatResult(generations=[ChatGeneration(message=AIMessage(content="Summary"))])
|
||||
|
||||
@property
|
||||
@@ -1155,14 +1153,9 @@ def test_summarization_middleware_full_workflow() -> None:
|
||||
|
||||
class MockModel(BaseChatModel):
|
||||
def invoke(self, prompt):
|
||||
from langchain_core.messages import AIMessage
|
||||
|
||||
return AIMessage(content="Generated summary")
|
||||
|
||||
def _generate(self, messages, **kwargs):
|
||||
from langchain_core.outputs import ChatResult, ChatGeneration
|
||||
from langchain_core.messages import AIMessage
|
||||
|
||||
return ChatResult(generations=[ChatGeneration(message=AIMessage(content="Summary"))])
|
||||
|
||||
@property
|
||||
@@ -1423,3 +1416,481 @@ def test_jump_to_is_ephemeral() -> None:
|
||||
agent = agent.compile()
|
||||
result = agent.invoke({"messages": [HumanMessage("Hello")]})
|
||||
assert "jump_to" not in result
|
||||
|
||||
|
||||
# Tests for PlanningMiddleware
|
||||
def test_planning_middleware_initialization() -> None:
|
||||
"""Test that PlanningMiddleware initializes correctly."""
|
||||
middleware = PlanningMiddleware()
|
||||
assert middleware.state_schema == PlanningState
|
||||
assert len(middleware.tools) == 1
|
||||
assert middleware.tools[0].name == "write_todos"
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"original_prompt,expected_prompt_prefix",
|
||||
[
|
||||
("Original prompt", "Original prompt\n\n## `write_todos`"),
|
||||
(None, "## `write_todos`"),
|
||||
],
|
||||
)
|
||||
def test_planning_middleware_modify_model_request(original_prompt, expected_prompt_prefix) -> None:
|
||||
"""Test that modify_model_request handles system prompts correctly."""
|
||||
middleware = PlanningMiddleware()
|
||||
model = FakeToolCallingModel()
|
||||
|
||||
request = ModelRequest(
|
||||
model=model,
|
||||
system_prompt=original_prompt,
|
||||
messages=[HumanMessage(content="Hello")],
|
||||
tool_choice=None,
|
||||
tools=[],
|
||||
response_format=None,
|
||||
model_settings={},
|
||||
)
|
||||
|
||||
state: PlanningState = {"messages": [HumanMessage(content="Hello")]}
|
||||
modified_request = middleware.modify_model_request(request, state)
|
||||
assert modified_request.system_prompt.startswith(expected_prompt_prefix)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"todos,expected_message",
|
||||
[
|
||||
([], "Updated todo list to []"),
|
||||
(
|
||||
[{"content": "Task 1", "status": "pending"}],
|
||||
"Updated todo list to [{'content': 'Task 1', 'status': 'pending'}]",
|
||||
),
|
||||
(
|
||||
[
|
||||
{"content": "Task 1", "status": "pending"},
|
||||
{"content": "Task 2", "status": "in_progress"},
|
||||
],
|
||||
"Updated todo list to [{'content': 'Task 1', 'status': 'pending'}, {'content': 'Task 2', 'status': 'in_progress'}]",
|
||||
),
|
||||
(
|
||||
[
|
||||
{"content": "Task 1", "status": "pending"},
|
||||
{"content": "Task 2", "status": "in_progress"},
|
||||
{"content": "Task 3", "status": "completed"},
|
||||
],
|
||||
"Updated todo list to [{'content': 'Task 1', 'status': 'pending'}, {'content': 'Task 2', 'status': 'in_progress'}, {'content': 'Task 3', 'status': 'completed'}]",
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_planning_middleware_write_todos_tool_execution(todos, expected_message) -> None:
|
||||
"""Test that the write_todos tool executes correctly."""
|
||||
tool_call = {
|
||||
"args": {"todos": todos},
|
||||
"name": "write_todos",
|
||||
"type": "tool_call",
|
||||
"id": "test_call",
|
||||
}
|
||||
result = write_todos.invoke(tool_call)
|
||||
assert result.update["todos"] == todos
|
||||
assert result.update["messages"][0].content == expected_message
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"invalid_todos",
|
||||
[
|
||||
[{"content": "Task 1", "status": "invalid_status"}],
|
||||
[{"status": "pending"}],
|
||||
],
|
||||
)
|
||||
def test_planning_middleware_write_todos_tool_validation_errors(invalid_todos) -> None:
|
||||
"""Test that the write_todos tool rejects invalid input."""
|
||||
tool_call = {
|
||||
"args": {"todos": invalid_todos},
|
||||
"name": "write_todos",
|
||||
"type": "tool_call",
|
||||
"id": "test_call",
|
||||
}
|
||||
with pytest.raises(Exception):
|
||||
write_todos.invoke(tool_call)
|
||||
|
||||
|
||||
def test_planning_middleware_agent_creation_with_middleware() -> None:
|
||||
"""Test that an agent can be created with the planning middleware."""
|
||||
model = FakeToolCallingModel(
|
||||
tool_calls=[
|
||||
[
|
||||
{
|
||||
"args": {"todos": [{"content": "Task 1", "status": "pending"}]},
|
||||
"name": "write_todos",
|
||||
"type": "tool_call",
|
||||
"id": "test_call",
|
||||
}
|
||||
],
|
||||
[
|
||||
{
|
||||
"args": {"todos": [{"content": "Task 1", "status": "in_progress"}]},
|
||||
"name": "write_todos",
|
||||
"type": "tool_call",
|
||||
"id": "test_call",
|
||||
}
|
||||
],
|
||||
[
|
||||
{
|
||||
"args": {"todos": [{"content": "Task 1", "status": "completed"}]},
|
||||
"name": "write_todos",
|
||||
"type": "tool_call",
|
||||
"id": "test_call",
|
||||
}
|
||||
],
|
||||
[],
|
||||
]
|
||||
)
|
||||
middleware = PlanningMiddleware()
|
||||
agent = create_agent(model=model, middleware=[middleware])
|
||||
agent = agent.compile()
|
||||
|
||||
result = agent.invoke({"messages": [HumanMessage("Hello")]})
|
||||
assert result["todos"] == [{"content": "Task 1", "status": "completed"}]
|
||||
|
||||
# human message (1)
|
||||
# ai message (2) - initial todo
|
||||
# tool message (3)
|
||||
# ai message (4) - updated todo
|
||||
# tool message (5)
|
||||
# ai message (6) - complete todo
|
||||
# tool message (7)
|
||||
# ai message (8) - no tool calls
|
||||
assert len(result["messages"]) == 8
|
||||
|
||||
|
||||
def test_planning_middleware_custom_system_prompt() -> None:
|
||||
"""Test that PlanningMiddleware can be initialized with custom system prompt."""
|
||||
custom_system_prompt = "Custom todo system prompt for testing"
|
||||
middleware = PlanningMiddleware(system_prompt=custom_system_prompt)
|
||||
model = FakeToolCallingModel()
|
||||
|
||||
request = ModelRequest(
|
||||
model=model,
|
||||
system_prompt="Original prompt",
|
||||
messages=[HumanMessage(content="Hello")],
|
||||
tool_choice=None,
|
||||
tools=[],
|
||||
response_format=None,
|
||||
model_settings={},
|
||||
)
|
||||
|
||||
state: PlanningState = {"messages": [HumanMessage(content="Hello")]}
|
||||
modified_request = middleware.modify_model_request(request, state)
|
||||
assert modified_request.system_prompt == f"Original prompt\n\n{custom_system_prompt}"
|
||||
|
||||
|
||||
def test_planning_middleware_custom_tool_description() -> None:
|
||||
"""Test that PlanningMiddleware can be initialized with custom tool description."""
|
||||
custom_tool_description = "Custom tool description for testing"
|
||||
middleware = PlanningMiddleware(tool_description=custom_tool_description)
|
||||
|
||||
assert len(middleware.tools) == 1
|
||||
tool = middleware.tools[0]
|
||||
assert tool.description == custom_tool_description
|
||||
|
||||
|
||||
def test_planning_middleware_custom_system_prompt_and_tool_description() -> None:
|
||||
"""Test that PlanningMiddleware can be initialized with both custom prompts."""
|
||||
custom_system_prompt = "Custom system prompt"
|
||||
custom_tool_description = "Custom tool description"
|
||||
middleware = PlanningMiddleware(
|
||||
system_prompt=custom_system_prompt,
|
||||
tool_description=custom_tool_description,
|
||||
)
|
||||
|
||||
# Verify system prompt
|
||||
model = FakeToolCallingModel()
|
||||
request = ModelRequest(
|
||||
model=model,
|
||||
system_prompt=None,
|
||||
messages=[HumanMessage(content="Hello")],
|
||||
tool_choice=None,
|
||||
tools=[],
|
||||
response_format=None,
|
||||
model_settings={},
|
||||
)
|
||||
|
||||
state: PlanningState = {"messages": [HumanMessage(content="Hello")]}
|
||||
modified_request = middleware.modify_model_request(request, state)
|
||||
assert modified_request.system_prompt == custom_system_prompt
|
||||
|
||||
# Verify tool description
|
||||
assert len(middleware.tools) == 1
|
||||
tool = middleware.tools[0]
|
||||
assert tool.description == custom_tool_description
|
||||
|
||||
|
||||
def test_planning_middleware_default_prompts() -> None:
|
||||
"""Test that PlanningMiddleware uses default prompts when none provided."""
|
||||
middleware = PlanningMiddleware()
|
||||
|
||||
# Verify default system prompt
|
||||
assert middleware.system_prompt == WRITE_TODOS_SYSTEM_PROMPT
|
||||
|
||||
# Verify default tool description
|
||||
assert middleware.tool_description == WRITE_TODOS_TOOL_DESCRIPTION
|
||||
assert len(middleware.tools) == 1
|
||||
tool = middleware.tools[0]
|
||||
assert tool.description == WRITE_TODOS_TOOL_DESCRIPTION
|
||||
|
||||
|
||||
def test_planning_middleware_custom_system_prompt() -> None:
|
||||
"""Test that custom tool executes correctly in an agent."""
|
||||
middleware = PlanningMiddleware(system_prompt="call the write_todos tool")
|
||||
|
||||
model = FakeToolCallingModel(
|
||||
tool_calls=[
|
||||
[
|
||||
{
|
||||
"args": {"todos": [{"content": "Custom task", "status": "pending"}]},
|
||||
"name": "write_todos",
|
||||
"type": "tool_call",
|
||||
"id": "test_call",
|
||||
}
|
||||
],
|
||||
[],
|
||||
]
|
||||
)
|
||||
|
||||
agent = create_agent(model=model, middleware=[middleware])
|
||||
agent = agent.compile()
|
||||
|
||||
result = agent.invoke({"messages": [HumanMessage("Hello")]})
|
||||
assert result["todos"] == [{"content": "Custom task", "status": "pending"}]
|
||||
# assert custom system prompt is in the first AI message
|
||||
assert "call the write_todos tool" in result["messages"][1].content
|
||||
|
||||
|
||||
# Async Middleware Tests
|
||||
async def test_create_agent_async_invoke() -> None:
|
||||
"""Test async invoke with async middleware hooks."""
|
||||
calls = []
|
||||
|
||||
class AsyncMiddleware(AgentMiddleware):
|
||||
async def abefore_model(self, state) -> None:
|
||||
calls.append("AsyncMiddleware.abefore_model")
|
||||
|
||||
async def amodify_model_request(self, request, state) -> ModelRequest:
|
||||
calls.append("AsyncMiddleware.amodify_model_request")
|
||||
request.messages.append(HumanMessage("async middleware message"))
|
||||
return request
|
||||
|
||||
async def aafter_model(self, state) -> None:
|
||||
calls.append("AsyncMiddleware.aafter_model")
|
||||
|
||||
@tool
|
||||
def my_tool(input: str) -> str:
|
||||
"""A great tool"""
|
||||
calls.append("my_tool")
|
||||
return input.upper()
|
||||
|
||||
agent = create_agent(
|
||||
model=FakeToolCallingModel(
|
||||
tool_calls=[
|
||||
[{"args": {"input": "yo"}, "id": "1", "name": "my_tool"}],
|
||||
[],
|
||||
]
|
||||
),
|
||||
tools=[my_tool],
|
||||
system_prompt="You are a helpful assistant.",
|
||||
middleware=[AsyncMiddleware()],
|
||||
).compile()
|
||||
|
||||
result = await agent.ainvoke({"messages": [HumanMessage("hello")]})
|
||||
|
||||
# Should have:
|
||||
# 1. Original hello message
|
||||
# 2. Async middleware message (first invoke)
|
||||
# 3. AI message with tool call
|
||||
# 4. Tool message
|
||||
# 5. Async middleware message (second invoke)
|
||||
# 6. Final AI message
|
||||
assert len(result["messages"]) == 6
|
||||
assert result["messages"][0].content == "hello"
|
||||
assert result["messages"][1].content == "async middleware message"
|
||||
assert calls == [
|
||||
"AsyncMiddleware.abefore_model",
|
||||
"AsyncMiddleware.amodify_model_request",
|
||||
"AsyncMiddleware.aafter_model",
|
||||
"my_tool",
|
||||
"AsyncMiddleware.abefore_model",
|
||||
"AsyncMiddleware.amodify_model_request",
|
||||
"AsyncMiddleware.aafter_model",
|
||||
]
|
||||
|
||||
|
||||
async def test_create_agent_async_invoke_multiple_middleware() -> None:
|
||||
"""Test async invoke with multiple async middleware hooks."""
|
||||
calls = []
|
||||
|
||||
class AsyncMiddlewareOne(AgentMiddleware):
|
||||
async def abefore_model(self, state) -> None:
|
||||
calls.append("AsyncMiddlewareOne.abefore_model")
|
||||
|
||||
async def amodify_model_request(self, request, state) -> ModelRequest:
|
||||
calls.append("AsyncMiddlewareOne.amodify_model_request")
|
||||
return request
|
||||
|
||||
async def aafter_model(self, state) -> None:
|
||||
calls.append("AsyncMiddlewareOne.aafter_model")
|
||||
|
||||
class AsyncMiddlewareTwo(AgentMiddleware):
|
||||
async def abefore_model(self, state) -> None:
|
||||
calls.append("AsyncMiddlewareTwo.abefore_model")
|
||||
|
||||
async def amodify_model_request(self, request, state) -> ModelRequest:
|
||||
calls.append("AsyncMiddlewareTwo.amodify_model_request")
|
||||
return request
|
||||
|
||||
async def aafter_model(self, state) -> None:
|
||||
calls.append("AsyncMiddlewareTwo.aafter_model")
|
||||
|
||||
agent = create_agent(
|
||||
model=FakeToolCallingModel(),
|
||||
tools=[],
|
||||
system_prompt="You are a helpful assistant.",
|
||||
middleware=[AsyncMiddlewareOne(), AsyncMiddlewareTwo()],
|
||||
).compile()
|
||||
|
||||
result = await agent.ainvoke({"messages": [HumanMessage("hello")]})
|
||||
|
||||
assert calls == [
|
||||
"AsyncMiddlewareOne.abefore_model",
|
||||
"AsyncMiddlewareTwo.abefore_model",
|
||||
"AsyncMiddlewareOne.amodify_model_request",
|
||||
"AsyncMiddlewareTwo.amodify_model_request",
|
||||
"AsyncMiddlewareTwo.aafter_model",
|
||||
"AsyncMiddlewareOne.aafter_model",
|
||||
]
|
||||
|
||||
|
||||
async def test_create_agent_async_jump() -> None:
|
||||
"""Test async invoke with async middleware using jump_to."""
|
||||
calls = []
|
||||
|
||||
class AsyncMiddlewareOne(AgentMiddleware):
|
||||
async def abefore_model(self, state) -> None:
            calls.append("AsyncMiddlewareOne.abefore_model")

    class AsyncMiddlewareTwo(AgentMiddleware):
        before_model_jump_to = ["end"]

        async def abefore_model(self, state) -> dict[str, Any]:
            calls.append("AsyncMiddlewareTwo.abefore_model")
            return {"jump_to": "end"}

    agent = create_agent(
        model=FakeToolCallingModel(),
        tools=[],
        system_prompt="You are a helpful assistant.",
        middleware=[AsyncMiddlewareOne(), AsyncMiddlewareTwo()],
    ).compile()

    result = await agent.ainvoke({"messages": []})

    assert result == {"messages": []}
    assert calls == ["AsyncMiddlewareOne.abefore_model", "AsyncMiddlewareTwo.abefore_model"]


async def test_create_agent_mixed_sync_async_middleware() -> None:
    """Test async invoke with mixed sync and async middleware."""
    calls = []

    class SyncMiddleware(AgentMiddleware):
        def before_model(self, state) -> None:
            calls.append("SyncMiddleware.before_model")

        def modify_model_request(self, request, state) -> ModelRequest:
            calls.append("SyncMiddleware.modify_model_request")
            return request

        def after_model(self, state) -> None:
            calls.append("SyncMiddleware.after_model")

    class AsyncMiddleware(AgentMiddleware):
        async def abefore_model(self, state) -> None:
            calls.append("AsyncMiddleware.abefore_model")

        async def amodify_model_request(self, request, state) -> ModelRequest:
            calls.append("AsyncMiddleware.amodify_model_request")
            return request

        async def aafter_model(self, state) -> None:
            calls.append("AsyncMiddleware.aafter_model")

    agent = create_agent(
        model=FakeToolCallingModel(),
        tools=[],
        system_prompt="You are a helpful assistant.",
        middleware=[SyncMiddleware(), AsyncMiddleware()],
    ).compile()

    result = await agent.ainvoke({"messages": [HumanMessage("hello")]})

    # In async mode, both sync and async middleware should work
    assert calls == [
        "SyncMiddleware.before_model",
        "AsyncMiddleware.abefore_model",
        "SyncMiddleware.modify_model_request",
        "AsyncMiddleware.amodify_model_request",
        "AsyncMiddleware.aafter_model",
        "SyncMiddleware.after_model",
    ]


def test_create_agent_sync_invoke_with_only_async_middleware_raises_error() -> None:
    """Test that sync invoke with only async middleware raises TypeError."""

    class AsyncOnlyMiddleware(AgentMiddleware):
        async def amodify_model_request(self, request, state) -> ModelRequest:
            return request

    agent = create_agent(
        model=FakeToolCallingModel(),
        tools=[],
        system_prompt="You are a helpful assistant.",
        middleware=[AsyncOnlyMiddleware()],
    ).compile()

    with pytest.raises(
        TypeError,
        match=r"No synchronous function provided for AsyncOnlyMiddleware\.amodify_model_request",
    ):
        agent.invoke({"messages": [HumanMessage("hello")]})


def test_create_agent_sync_invoke_with_mixed_middleware() -> None:
    """Test that sync invoke works with mixed sync/async middleware when sync versions exist."""
    calls = []

    class MixedMiddleware(AgentMiddleware):
        def before_model(self, state) -> None:
            calls.append("MixedMiddleware.before_model")

        async def abefore_model(self, state) -> None:
            calls.append("MixedMiddleware.abefore_model")

        def modify_model_request(self, request, state) -> ModelRequest:
            calls.append("MixedMiddleware.modify_model_request")
            return request

        async def amodify_model_request(self, request, state) -> ModelRequest:
            calls.append("MixedMiddleware.amodify_model_request")
            return request

    agent = create_agent(
        model=FakeToolCallingModel(),
        tools=[],
        system_prompt="You are a helpful assistant.",
        middleware=[MixedMiddleware()],
    ).compile()

    result = agent.invoke({"messages": [HumanMessage("hello")]})

    # In sync mode, only sync methods should be called
    assert calls == [
        "MixedMiddleware.before_model",
        "MixedMiddleware.modify_model_request",
    ]
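Together, these tests pin down the dispatch rule: ainvoke runs sync hooks alongside async ones, while invoke raises TypeError for any hook that exists only in async form. A minimal sketch of a middleware that is safe from both entry points (the class name and import path here are illustrative assumptions, not part of this change):

from langchain.agents.middleware import AgentMiddleware  # assumed import path


class DualModeMiddleware(AgentMiddleware):
    """Hypothetical middleware usable from both invoke and ainvoke."""

    def before_model(self, state) -> None:
        # Sync hook: used by agent.invoke(...).
        ...

    async def abefore_model(self, state) -> None:
        # Async hook: used by agent.ainvoke(...). Without the sync variant
        # above, agent.invoke(...) would raise TypeError.
        ...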
@@ -1,5 +1,6 @@
 """Consolidated tests for middleware decorators: before_model, after_model, and modify_model_request."""

+import pytest
 from typing import Any
 from typing_extensions import NotRequired

@@ -150,3 +151,161 @@ def test_decorators_use_function_names_as_default() -> None:
     assert my_before_hook.__class__.__name__ == "my_before_hook"
     assert my_modify_hook.__class__.__name__ == "my_modify_hook"
     assert my_after_hook.__class__.__name__ == "my_after_hook"
+
+
+# Async Decorator Tests
+
+
+def test_async_before_model_decorator() -> None:
+    """Test before_model decorator with async function."""
+
+    @before_model(state_schema=CustomState, tools=[test_tool], name="AsyncBeforeModel")
+    async def async_before_model(state: CustomState) -> dict[str, Any]:
+        return {"custom_field": "async_value"}
+
+    assert isinstance(async_before_model, AgentMiddleware)
+    assert async_before_model.state_schema == CustomState
+    assert async_before_model.tools == [test_tool]
+    assert async_before_model.__class__.__name__ == "AsyncBeforeModel"
+
+
+def test_async_after_model_decorator() -> None:
+    """Test after_model decorator with async function."""
+
+    @after_model(state_schema=CustomState, tools=[test_tool], name="AsyncAfterModel")
+    async def async_after_model(state: CustomState) -> dict[str, Any]:
+        return {"custom_field": "async_value"}
+
+    assert isinstance(async_after_model, AgentMiddleware)
+    assert async_after_model.state_schema == CustomState
+    assert async_after_model.tools == [test_tool]
+    assert async_after_model.__class__.__name__ == "AsyncAfterModel"
+
+
+def test_async_modify_model_request_decorator() -> None:
+    """Test modify_model_request decorator with async function."""
+
+    @modify_model_request(state_schema=CustomState, tools=[test_tool], name="AsyncModifyRequest")
+    async def async_modify_request(request: ModelRequest, state: CustomState) -> ModelRequest:
+        request.system_prompt = "Modified async"
+        return request
+
+    assert isinstance(async_modify_request, AgentMiddleware)
+    assert async_modify_request.state_schema == CustomState
+    assert async_modify_request.tools == [test_tool]
+    assert async_modify_request.__class__.__name__ == "AsyncModifyRequest"
+
+
+def test_mixed_sync_async_decorators() -> None:
+    """Test decorators with both sync and async functions."""
+
+    @before_model(name="MixedBeforeModel")
+    def sync_before(state: AgentState) -> None:
+        return None
+
+    @before_model(name="MixedBeforeModel")
+    async def async_before(state: AgentState) -> None:
+        return None
+
+    @modify_model_request(name="MixedModifyRequest")
+    def sync_modify(request: ModelRequest, state: AgentState) -> ModelRequest:
+        return request
+
+    @modify_model_request(name="MixedModifyRequest")
+    async def async_modify(request: ModelRequest, state: AgentState) -> ModelRequest:
+        return request
+
+    # Both should create valid middleware instances
+    assert isinstance(sync_before, AgentMiddleware)
+    assert isinstance(async_before, AgentMiddleware)
+    assert isinstance(sync_modify, AgentMiddleware)
+    assert isinstance(async_modify, AgentMiddleware)
+
+
+@pytest.mark.asyncio
+async def test_async_decorators_integration() -> None:
+    """Test async decorators working together in an agent."""
+    call_order = []
+
+    @before_model
+    async def track_async_before(state: AgentState) -> None:
+        call_order.append("async_before")
+        return None
+
+    @modify_model_request
+    async def track_async_modify(request: ModelRequest, state: AgentState) -> ModelRequest:
+        call_order.append("async_modify")
+        return request
+
+    @after_model
+    async def track_async_after(state: AgentState) -> None:
+        call_order.append("async_after")
+        return None
+
+    agent = create_agent(
+        model=FakeToolCallingModel(),
+        middleware=[track_async_before, track_async_modify, track_async_after],
+    )
+    agent = agent.compile()
+    await agent.ainvoke({"messages": [HumanMessage("Hello")]})
+
+    assert call_order == ["async_before", "async_modify", "async_after"]
+
+
+@pytest.mark.asyncio
+async def test_mixed_sync_async_decorators_integration() -> None:
+    """Test mixed sync/async decorators working together in an agent."""
+    call_order = []
+
+    @before_model
+    def track_sync_before(state: AgentState) -> None:
+        call_order.append("sync_before")
+        return None
+
+    @before_model
+    async def track_async_before(state: AgentState) -> None:
+        call_order.append("async_before")
+        return None
+
+    @modify_model_request
+    def track_sync_modify(request: ModelRequest, state: AgentState) -> ModelRequest:
+        call_order.append("sync_modify")
+        return request
+
+    @modify_model_request
+    async def track_async_modify(request: ModelRequest, state: AgentState) -> ModelRequest:
+        call_order.append("async_modify")
+        return request
+
+    @after_model
+    async def track_async_after(state: AgentState) -> None:
+        call_order.append("async_after")
+        return None
+
+    @after_model
+    def track_sync_after(state: AgentState) -> None:
+        call_order.append("sync_after")
+        return None
+
+    agent = create_agent(
+        model=FakeToolCallingModel(),
+        middleware=[
+            track_sync_before,
+            track_async_before,
+            track_sync_modify,
+            track_async_modify,
+            track_async_after,
+            track_sync_after,
+        ],
+    )
+    agent = agent.compile()
+    await agent.ainvoke({"messages": [HumanMessage("Hello")]})
+
+    assert call_order == [
+        "sync_before",
+        "async_before",
+        "sync_modify",
+        "async_modify",
+        "sync_after",
+        "async_after",
+    ]
@@ -267,6 +267,7 @@ def test_configurable_with_default() -> None:
         "stop_sequences": None,
         "anthropic_api_url": "https://api.anthropic.com",
         "anthropic_proxy": None,
+        "context_management": None,
         "anthropic_api_key": SecretStr("bar"),
         "betas": None,
         "default_headers": None,
@@ -12,7 +12,7 @@ from operator import itemgetter
 from typing import Any, Callable, Literal, Optional, Union, cast

 import anthropic
-from langchain_core._api import beta, deprecated
+from langchain_core._api import deprecated
 from langchain_core.callbacks import (
     AsyncCallbackManagerForLLMRun,
     CallbackManagerForLLMRun,
@@ -91,6 +91,7 @@ def _is_builtin_tool(tool: Any) -> bool:
         "web_search_",
         "web_fetch_",
         "code_execution_",
+        "memory_",
     ]
     return any(tool_type.startswith(prefix) for prefix in _builtin_tool_prefixes)
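Since "memory_" now matches _is_builtin_tool, a memory tool dict is forwarded to the Anthropic API as-is rather than converted to a custom tool schema. A short sketch using the same tool spec that the docstring and tests in this change use:

from langchain_anthropic import ChatAnthropic

llm = ChatAnthropic(
    model="claude-sonnet-4-5-20250929",
    betas=["context-management-2025-06-27"],
)
# Recognized as a server-side tool via the new "memory_" prefix.
llm_with_memory = llm.bind_tools([{"type": "memory_20250818", "name": "memory"}])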
@@ -1193,6 +1194,25 @@ class ChatAnthropic(BaseChatModel):

             Total tokens: 408

+    Context management:
+        Anthropic supports a context editing feature that will automatically manage the
+        model's context window (e.g., by clearing tool results).
+
+        See `Anthropic documentation <https://docs.claude.com/en/docs/build-with-claude/context-editing>`__
+        for details and configuration options.
+
+        .. code-block:: python
+
+            from langchain_anthropic import ChatAnthropic
+
+            llm = ChatAnthropic(
+                model="claude-sonnet-4-5-20250929",
+                betas=["context-management-2025-06-27"],
+                context_management={"edits": [{"type": "clear_tool_uses_20250919"}]},
+            )
+            llm_with_tools = llm.bind_tools([{"type": "web_search_20250305", "name": "web_search"}])
+            response = llm_with_tools.invoke("Search for recent developments in AI")
+
     Built-in tools:
         See LangChain `docs <https://python.langchain.com/docs/integrations/chat/anthropic/#built-in-tools>`__
         for more detail.
@@ -1306,6 +1326,19 @@ class ChatAnthropic(BaseChatModel):
             'id': 'toolu_01VdNgt1YV7kGfj9LFLm6HyQ',
             'type': 'tool_call'}]

+    .. dropdown:: Memory tool
+
+        .. code-block:: python
+
+            from langchain_anthropic import ChatAnthropic
+
+            llm = ChatAnthropic(
+                model="claude-sonnet-4-5-20250929",
+                betas=["context-management-2025-06-27"],
+            )
+            llm_with_tools = llm.bind_tools([{"type": "memory_20250818", "name": "memory"}])
+
+            response = llm_with_tools.invoke("What are my interests?")
+
     Response metadata
         .. code-block:: python
@@ -1413,6 +1446,11 @@ class ChatAnthropic(BaseChatModel):
        "name": "example-mcp"}]``
     """

+    context_management: Optional[dict[str, Any]] = None
+    """Configuration for
+    `context management <https://docs.claude.com/en/docs/build-with-claude/context-editing>`__.
+    """
+
     @property
     def _llm_type(self) -> str:
         """Return type of chat model."""
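For orientation, a hedged sketch of how the new field travels: set on the model, it is serialized into the request payload (see the payload hunk below), and any edits Anthropic applies come back in response_metadata, as the integration test further down asserts:

from langchain_anthropic import ChatAnthropic

llm = ChatAnthropic(
    model="claude-sonnet-4-5-20250929",
    betas=["context-management-2025-06-27"],
    context_management={"edits": [{"type": "clear_tool_uses_20250919"}]},
)
llm_with_tools = llm.bind_tools([{"type": "web_search_20250305", "name": "web_search"}])
response = llm_with_tools.invoke("Search for recent developments in AI")
# Context edits applied by the API, if any, are surfaced here.
print(response.response_metadata.get("context_management"))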
@@ -1565,6 +1603,7 @@ class ChatAnthropic(BaseChatModel):
             "top_p": self.top_p,
             "stop_sequences": stop or self.stop_sequences,
             "betas": self.betas,
+            "context_management": self.context_management,
             "mcp_servers": self.mcp_servers,
             "system": system,
             **self.model_kwargs,
@@ -2219,7 +2258,6 @@ class ChatAnthropic(BaseChatModel):
             return RunnableMap(raw=llm) | parser_with_fallback
         return llm | output_parser

-    @beta()
     def get_num_tokens_from_messages(
         self,
         messages: list[BaseMessage],
@@ -2234,8 +2272,8 @@ class ChatAnthropic(BaseChatModel):
             messages: The message inputs to tokenize.
             tools: If provided, sequence of dict, BaseModel, function, or BaseTools
                 to be converted to tool schemas.
-            kwargs: Additional keyword arguments are passed to the
-                :meth:`~langchain_anthropic.chat_models.ChatAnthropic.bind` method.
+            kwargs: Additional keyword arguments are passed to the Anthropic
+                ``messages.count_tokens`` method.

         Basic usage:

@@ -2270,7 +2308,7 @@ class ChatAnthropic(BaseChatModel):
             def get_weather(location: str) -> str:
                 \"\"\"Get the current weather in a given location

-            Args:
+                Args:
                     location: The city and state, e.g. San Francisco, CA
                 \"\"\"
                 return "Sunny"
@@ -2288,15 +2326,24 @@ class ChatAnthropic(BaseChatModel):

         Uses Anthropic's `token counting API <https://docs.anthropic.com/en/docs/build-with-claude/token-counting>`__ to count tokens in messages.

-        """  # noqa: E501
+        """  # noqa: D214,E501
         formatted_system, formatted_messages = _format_messages(messages)
         if isinstance(formatted_system, str):
             kwargs["system"] = formatted_system
         if tools:
             kwargs["tools"] = [convert_to_anthropic_tool(tool) for tool in tools]
+        if self.context_management is not None:
+            kwargs["context_management"] = self.context_management

-        response = self._client.beta.messages.count_tokens(
-            betas=["token-counting-2024-11-01"],
+        if self.betas is not None:
+            beta_response = self._client.beta.messages.count_tokens(
+                betas=self.betas,
+                model=self.model,
+                messages=formatted_messages,  # type: ignore[arg-type]
+                **kwargs,
+            )
+            return beta_response.input_tokens
+        response = self._client.messages.count_tokens(
             model=self.model,
             messages=formatted_messages,  # type: ignore[arg-type]
             **kwargs,
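From the caller's side, the branching above is invisible: the same method now routes through the beta client only when betas is set, forwarding those flags instead of the old hardcoded token-counting beta. A sketch under that assumption:

from langchain_core.messages import HumanMessage
from langchain_anthropic import ChatAnthropic

# No betas configured: counts via the stable messages.count_tokens endpoint.
llm = ChatAnthropic(model="claude-sonnet-4-5-20250929")
n = llm.get_num_tokens_from_messages([HumanMessage("hello")])

# betas configured: counts via beta.messages.count_tokens with those flags.
llm_beta = ChatAnthropic(
    model="claude-sonnet-4-5-20250929",
    betas=["context-management-2025-06-27"],
)
n_beta = llm_beta.get_num_tokens_from_messages([HumanMessage("hello")])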
@@ -2409,7 +2456,7 @@ def _make_message_chunk_from_anthropic_event(
     # Capture model name, but don't include usage_metadata yet
     # as it will be properly reported in message_delta with complete info
     if hasattr(event.message, "model"):
-        response_metadata = {"model_name": event.message.model}
+        response_metadata: dict[str, Any] = {"model_name": event.message.model}
     else:
         response_metadata = {}

@@ -2510,13 +2557,16 @@ def _make_message_chunk_from_anthropic_event(
     # Process final usage metadata and completion info
     elif event.type == "message_delta" and stream_usage:
         usage_metadata = _create_usage_metadata(event.usage)
+        response_metadata = {
+            "stop_reason": event.delta.stop_reason,
+            "stop_sequence": event.delta.stop_sequence,
+        }
+        if context_management := getattr(event, "context_management", None):
+            response_metadata["context_management"] = context_management.model_dump()
         message_chunk = AIMessageChunk(
             content="",
             usage_metadata=usage_metadata,
-            response_metadata={
-                "stop_reason": event.delta.stop_reason,
-                "stop_sequence": event.delta.stop_sequence,
-            },
+            response_metadata=response_metadata,
         )
     # Unhandled event types (e.g., `content_block_stop`, `ping` events)
     # https://docs.anthropic.com/en/docs/build-with-claude/streaming#other-events
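Because context_management now rides on the message_delta chunk's metadata, a streaming consumer can recover it by merging chunks, mirroring what the integration test below does. A brief sketch:

from langchain_anthropic import ChatAnthropic

llm = ChatAnthropic(
    model="claude-sonnet-4-5-20250929",
    betas=["context-management-2025-06-27"],
    context_management={"edits": [{"type": "clear_tool_uses_20250919"}]},
)
full = None
for chunk in llm.stream("Search for recent developments in AI"):
    # AIMessageChunk addition merges response_metadata across chunks.
    full = chunk if full is None else full + chunk
if full is not None:
    print(full.response_metadata.get("context_management"))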
@@ -7,12 +7,12 @@ authors = []
 license = { text = "MIT" }
 requires-python = ">=3.9.0,<4.0.0"
 dependencies = [
-    "anthropic>=0.67.0,<1.0.0",
+    "anthropic>=0.69.0,<1.0.0",
     "langchain-core>=0.3.76,<2.0.0",
     "pydantic>=2.7.4,<3.0.0",
 ]
 name = "langchain-anthropic"
-version = "0.3.20"
+version = "0.3.21"
 description = "An integration package connecting Anthropic and LangChain"
 readme = "README.md"
Binary file not shown.
@@ -1485,6 +1485,50 @@ def test_search_result_top_level() -> None:
     assert any("citations" in block for block in result.content)


+def test_memory_tool() -> None:
+    llm = ChatAnthropic(
+        model="claude-sonnet-4-5-20250929",  # type: ignore[call-arg]
+        betas=["context-management-2025-06-27"],
+    )
+    llm_with_tools = llm.bind_tools([{"type": "memory_20250818", "name": "memory"}])
+    response = llm_with_tools.invoke("What are my interests?")
+    assert isinstance(response, AIMessage)
+    assert response.tool_calls
+    assert response.tool_calls[0]["name"] == "memory"
+
+
+@pytest.mark.vcr
+def test_context_management() -> None:
+    # TODO: update example to trigger action
+    llm = ChatAnthropic(
+        model="claude-sonnet-4-5-20250929",  # type: ignore[call-arg]
+        betas=["context-management-2025-06-27"],
+        context_management={
+            "edits": [
+                {
+                    "type": "clear_tool_uses_20250919",
+                    "trigger": {"type": "input_tokens", "value": 10},
+                    "clear_at_least": {"type": "input_tokens", "value": 5},
+                }
+            ]
+        },
+    )
+    llm_with_tools = llm.bind_tools(
+        [{"type": "web_search_20250305", "name": "web_search"}]
+    )
+    input_message = {"role": "user", "content": "Search for recent developments in AI"}
+    response = llm_with_tools.invoke([input_message])
+    assert response.response_metadata.get("context_management")
+
+    # Test streaming
+    full: Optional[BaseMessageChunk] = None
+    for chunk in llm_with_tools.stream([input_message]):
+        assert isinstance(chunk, AIMessageChunk)
+        full = chunk if full is None else full + chunk
+    assert isinstance(full, AIMessageChunk)
+    assert full.response_metadata.get("context_management")
+
+
 def test_async_shared_client() -> None:
     llm = ChatAnthropic(model="claude-3-5-haiku-latest")  # type: ignore[call-arg]
     _ = asyncio.run(llm.ainvoke("Hello"))
@@ -1065,9 +1065,21 @@ def test_get_num_tokens_from_messages_passes_kwargs() -> None:
     with patch.object(anthropic, "Client") as _client:
         llm.get_num_tokens_from_messages([HumanMessage("foo")], foo="bar")

-    assert (
-        _client.return_value.beta.messages.count_tokens.call_args.kwargs["foo"] == "bar"
-    )
+    assert _client.return_value.messages.count_tokens.call_args.kwargs["foo"] == "bar"
+
+    llm = ChatAnthropic(
+        model="claude-sonnet-4-5-20250929",
+        betas=["context-management-2025-06-27"],
+        context_management={"edits": [{"type": "clear_tool_uses_20250919"}]},
+    )
+    with patch.object(anthropic, "Client") as _client:
+        llm.get_num_tokens_from_messages([HumanMessage("foo")])
+
+    call_args = _client.return_value.beta.messages.count_tokens.call_args.kwargs
+    assert call_args["betas"] == ["context-management-2025-06-27"]
+    assert call_args["context_management"] == {
+        "edits": [{"type": "clear_tool_uses_20250919"}]
+    }


 def test_usage_metadata_standardization() -> None:
@@ -1217,6 +1229,22 @@ def test_cache_control_kwarg() -> None:
     ]


+def test_context_management_in_payload() -> None:
+    llm = ChatAnthropic(
+        model="claude-sonnet-4-5-20250929",  # type: ignore[call-arg]
+        betas=["context-management-2025-06-27"],
+        context_management={"edits": [{"type": "clear_tool_uses_20250919"}]},
+    )
+    llm_with_tools = llm.bind_tools(
+        [{"type": "web_search_20250305", "name": "web_search"}]
+    )
+    input_message = HumanMessage("Search for recent developments in AI")
+    payload = llm_with_tools._get_request_payload([input_message])  # type: ignore[attr-defined]
+    assert payload["context_management"] == {
+        "edits": [{"type": "clear_tool_uses_20250919"}]
+    }
+
+
 def test_anthropic_model_params() -> None:
     llm = ChatAnthropic(model="claude-3-5-haiku-latest")
22 changes: libs/partners/anthropic/uv.lock (generated)
@@ -1,5 +1,5 @@
 version = 1
-revision = 3
+revision = 2
 requires-python = ">=3.9.0, <4.0.0"
 resolution-markers = [
     "python_full_version >= '3.13' and platform_python_implementation == 'PyPy'",
@@ -21,20 +21,21 @@ wheels = [

 [[package]]
 name = "anthropic"
-version = "0.67.0"
+version = "0.69.0"
 source = { registry = "https://pypi.org/simple" }
 dependencies = [
     { name = "anyio" },
     { name = "distro" },
+    { name = "docstring-parser" },
     { name = "httpx" },
     { name = "jiter" },
     { name = "pydantic" },
     { name = "sniffio" },
     { name = "typing-extensions" },
 ]
-sdist = { url = "https://files.pythonhosted.org/packages/09/08/ee91464cd821e6fca52d9a23be44815c95edd3c1cf1e844b2c5e85f0d57f/anthropic-0.67.0.tar.gz", hash = "sha256:d1531b210ea300c73423141d29bcee20fcd24ef9f426f6437c0a5d93fc98fb8e", size = 441639, upload-time = "2025-09-10T14:47:18.137Z" }
+sdist = { url = "https://files.pythonhosted.org/packages/c8/9d/9ad1778b95f15c5b04e7d328c1b5f558f1e893857b7c33cd288c19c0057a/anthropic-0.69.0.tar.gz", hash = "sha256:c604d287f4d73640f40bd2c0f3265a2eb6ce034217ead0608f6b07a8bc5ae5f2", size = 480622, upload-time = "2025-09-29T16:53:45.282Z" }
 wheels = [
-    { url = "https://files.pythonhosted.org/packages/5c/9d/9adbda372710918cc8271d089a2ceae4d977a125f90bc3c4b456bca4f281/anthropic-0.67.0-py3-none-any.whl", hash = "sha256:f80a81ec1132c514215f33d25edeeab1c4691ad5361b391ebb70d528b0605b55", size = 317126, upload-time = "2025-09-10T14:47:16.351Z" },
+    { url = "https://files.pythonhosted.org/packages/9b/38/75129688de5637eb5b383e5f2b1570a5cc3aecafa4de422da8eea4b90a6c/anthropic-0.69.0-py3-none-any.whl", hash = "sha256:1f73193040f33f11e27c2cd6ec25f24fe7c3f193dc1c5cde6b7a08b18a16bcc5", size = 337265, upload-time = "2025-09-29T16:53:43.686Z" },
 ]

 [[package]]
@@ -257,6 +258,15 @@ wheels = [
     { url = "https://files.pythonhosted.org/packages/12/b3/231ffd4ab1fc9d679809f356cebee130ac7daa00d6d6f3206dd4fd137e9e/distro-1.9.0-py3-none-any.whl", hash = "sha256:7bffd925d65168f85027d8da9af6bddab658135b840670a223589bc0c8ef02b2", size = 20277, upload-time = "2023-12-24T09:54:30.421Z" },
 ]

+[[package]]
+name = "docstring-parser"
+version = "0.17.0"
+source = { registry = "https://pypi.org/simple" }
+sdist = { url = "https://files.pythonhosted.org/packages/b2/9d/c3b43da9515bd270df0f80548d9944e389870713cc1fe2b8fb35fe2bcefd/docstring_parser-0.17.0.tar.gz", hash = "sha256:583de4a309722b3315439bb31d64ba3eebada841f2e2cee23b99df001434c912", size = 27442, upload-time = "2025-07-21T07:35:01.868Z" }
+wheels = [
+    { url = "https://files.pythonhosted.org/packages/55/e2/2537ebcff11c1ee1ff17d8d0b6f4db75873e3b0fb32c2d4a2ee31ecb310a/docstring_parser-0.17.0-py3-none-any.whl", hash = "sha256:cf2569abd23dce8099b300f9b4fa8191e9582dda731fd533daf54c4551658708", size = 36896, upload-time = "2025-07-21T07:35:00.684Z" },
+]
+
 [[package]]
 name = "exceptiongroup"
 version = "1.3.0"
@@ -455,7 +465,7 @@ wheels = [

 [[package]]
 name = "langchain-anthropic"
-version = "0.3.20"
+version = "0.3.21"
 source = { editable = "." }
 dependencies = [
     { name = "anthropic" },
@@ -497,7 +507,7 @@ typing = [

 [package.metadata]
 requires-dist = [
-    { name = "anthropic", specifier = ">=0.67.0,<1.0.0" },
+    { name = "anthropic", specifier = ">=0.69.0,<1.0.0" },
     { name = "langchain-core", editable = "../../core" },
     { name = "pydantic", specifier = ">=2.7.4,<3.0.0" },
 ]
@@ -85,6 +85,7 @@ ignore = [
     "RUF001",
     "ERA001",
     "PLR0911",
+    "FA100",  # from __future__ import annotations breaks some schema conversion logic

     # TODO
     "PLR2004",  # Comparison to magic number
@@ -1,7 +1,5 @@
 """Test ChatOpenAI chat model."""

-from __future__ import annotations
-
 import base64
 import json
 from collections.abc import AsyncIterator
@@ -1,7 +1,5 @@
 """Test Responses API usage."""

-from __future__ import annotations
-
 import json
 import os
 from typing import Annotated, Any, Literal, Optional, cast