feat(langchain): todo middleware (#33152)

Porting the planning middleware (`src/deepagents/middleware.py`, commit `39c0138d0f`, L21)
over from deepagents.

Also adding the ability to configure:
* System prompt
* Tool description

```py
from langchain.agents.middleware.planning import PlanningMiddleware
from langchain.agents import create_agent

agent = create_agent("openai:gpt-4o", middleware=[PlanningMiddleware()])

result = agent.invoke({"messages": [HumanMessage("Help me refactor my codebase")]})

print(result["todos"])  # Array of todo items with status tracking
```
This commit is contained in:
Sydney Runkle
2025-09-29 19:23:26 -07:00
committed by GitHub
parent 729637a347
commit eed0f6c289
3 changed files with 469 additions and 32 deletions

View File

@@ -1,6 +1,7 @@
"""Middleware plugins for agents."""
from .human_in_the_loop import HumanInTheLoopMiddleware
from .planning import PlanningMiddleware
from .prompt_caching import AnthropicPromptCachingMiddleware
from .summarization import SummarizationMiddleware
from .types import AgentMiddleware, AgentState, ModelRequest
@@ -12,5 +13,6 @@ __all__ = [
"AnthropicPromptCachingMiddleware",
"HumanInTheLoopMiddleware",
"ModelRequest",
"PlanningMiddleware",
"SummarizationMiddleware",
]

View File

@@ -0,0 +1,197 @@
"""Planning and task management middleware for agents."""
# ruff: noqa: E501
from __future__ import annotations
from typing import Annotated, Literal
from langchain_core.messages import ToolMessage
from langchain_core.tools import tool
from langgraph.types import Command
from typing_extensions import NotRequired, TypedDict
from langchain.agents.middleware.types import AgentMiddleware, AgentState, ModelRequest
from langchain.tools import InjectedToolCallId
class Todo(TypedDict):
    """A single todo item tracked by the planning middleware."""

    content: str
    """The content/description of the todo item."""

    status: Literal["pending", "in_progress", "completed"]
    """The current status of the todo item: not started, being worked on, or done."""
class PlanningState(AgentState):
    """State schema for the todo middleware."""

    # NotRequired: the key is absent until the agent first calls the write_todos tool.
    todos: NotRequired[list[Todo]]
    """List of todo items for tracking task progress."""
# Default description surfaced to the model for the write_todos tool.
# Fix: removed the stray period after "immediately!" in the Task Management notes.
WRITE_TODOS_TOOL_DESCRIPTION = """Use this tool to create and manage a structured task list for your current work session. This helps you track progress, organize complex tasks, and demonstrate thoroughness to the user.
Only use this tool if you think it will be helpful in staying organized. If the user's request is trivial and takes less than 3 steps, it is better to NOT use this tool and just do the task directly.
## When to Use This Tool
Use this tool in these scenarios:
1. Complex multi-step tasks - When a task requires 3 or more distinct steps or actions
2. Non-trivial and complex tasks - Tasks that require careful planning or multiple operations
3. User explicitly requests todo list - When the user directly asks you to use the todo list
4. User provides multiple tasks - When users provide a list of things to be done (numbered or comma-separated)
5. The plan may need future revisions or updates based on results from the first few steps
## How to Use This Tool
1. When you start working on a task - Mark it as in_progress BEFORE beginning work.
2. After completing a task - Mark it as completed and add any new follow-up tasks discovered during implementation.
3. You can also update future tasks, such as deleting them if they are no longer necessary, or adding new tasks that are necessary. Don't change previously completed tasks.
4. You can make several updates to the todo list at once. For example, when you complete a task, you can mark the next task you need to start as in_progress.
## When NOT to Use This Tool
It is important to skip using this tool when:
1. There is only a single, straightforward task
2. The task is trivial and tracking it provides no benefit
3. The task can be completed in less than 3 trivial steps
4. The task is purely conversational or informational
## Task States and Management
1. **Task States**: Use these states to track progress:
- pending: Task not yet started
- in_progress: Currently working on (you can have multiple tasks in_progress at a time if they are not related to each other and can be run in parallel)
- completed: Task finished successfully
2. **Task Management**:
- Update task status in real-time as you work
- Mark tasks complete IMMEDIATELY after finishing (don't batch completions)
- Complete current tasks before starting new ones
- Remove tasks that are no longer relevant from the list entirely
- IMPORTANT: When you write this todo list, you should mark your first task (or tasks) as in_progress immediately!
- IMPORTANT: Unless all tasks are completed, you should always have at least one task in_progress to show the user that you are working on something.
3. **Task Completion Requirements**:
- ONLY mark a task as completed when you have FULLY accomplished it
- If you encounter errors, blockers, or cannot finish, keep the task as in_progress
- When blocked, create a new task describing what needs to be resolved
- Never mark a task as completed if:
- There are unresolved issues or errors
- Work is partial or incomplete
- You encountered blockers that prevent completion
- You couldn't find necessary resources or dependencies
- Quality standards haven't been met
4. **Task Breakdown**:
- Create specific, actionable items
- Break complex tasks into smaller, manageable steps
- Use clear, descriptive task names
Being proactive with task management demonstrates attentiveness and ensures you complete all requirements successfully
Remember: If you only need to make a few tool calls to complete a task, and it is clear what you need to do, it is better to just do the task directly and NOT call this tool at all."""
# Default guidance appended to the agent's system prompt when the planning
# middleware is active (see PlanningMiddleware.modify_model_request).
WRITE_TODOS_SYSTEM_PROMPT = """## `write_todos`
You have access to the `write_todos` tool to help you manage and plan complex objectives.
Use this tool for complex objectives to ensure that you are tracking each necessary step and giving the user visibility into your progress.
This tool is very helpful for planning complex objectives, and for breaking down these larger complex objectives into smaller steps.
It is critical that you mark todos as completed as soon as you are done with a step. Do not batch up multiple steps before marking them as completed.
For simple objectives that only require a few steps, it is better to just complete the objective directly and NOT use this tool.
Writing todos takes time and tokens, use it when it is helpful for managing complex many-step problems! But not for simple few-step requests.
## Important To-Do List Usage Notes to Remember
- The `write_todos` tool should never be called multiple times in parallel.
- Don't be afraid to revise the To-Do list as you go. New information may reveal new tasks that need to be done, or old tasks that are irrelevant."""
@tool(description=WRITE_TODOS_TOOL_DESCRIPTION)
def write_todos(todos: list[Todo], tool_call_id: Annotated[str, InjectedToolCallId]) -> Command:
    """Replace the agent's todo list and acknowledge the change with a ToolMessage."""
    # Acknowledge the update so the model sees the new list reflected in history.
    ack = ToolMessage(f"Updated todo list to {todos}", tool_call_id=tool_call_id)
    state_update = {"todos": todos, "messages": [ack]}
    return Command(update=state_update)
class PlanningMiddleware(AgentMiddleware):
    """Middleware that provides todo list management capabilities to agents.

    This middleware adds a `write_todos` tool that allows agents to create and manage
    structured task lists for complex multi-step operations. It's designed to help
    agents track progress, organize complex tasks, and provide users with visibility
    into task completion status.

    The middleware automatically injects system prompts that guide the agent on when
    and how to use the todo functionality effectively.

    Example:
        ```python
        from langchain.agents.middleware.planning import PlanningMiddleware
        from langchain.agents import create_agent

        agent = create_agent("openai:gpt-4o", middleware=[PlanningMiddleware()])

        # `invoke` is synchronous; use `ainvoke` from async code.
        result = agent.invoke({"messages": [HumanMessage("Help me refactor my codebase")]})
        print(result["todos"])  # List of todo items with status tracking
        ```

    Args:
        system_prompt: Custom system prompt to guide the agent on using the todo tool.
            If not provided, uses the default ``WRITE_TODOS_SYSTEM_PROMPT``.
        tool_description: Custom description for the write_todos tool.
            If not provided, uses the default ``WRITE_TODOS_TOOL_DESCRIPTION``.
    """

    state_schema = PlanningState

    def __init__(
        self,
        *,
        system_prompt: str = WRITE_TODOS_SYSTEM_PROMPT,
        tool_description: str = WRITE_TODOS_TOOL_DESCRIPTION,
    ) -> None:
        """Initialize the PlanningMiddleware with optional custom prompts.

        Args:
            system_prompt: Custom system prompt to guide the agent on using the todo tool.
            tool_description: Custom description for the write_todos tool.
        """
        super().__init__()
        self.system_prompt = system_prompt
        self.tool_description = tool_description

        # Re-create the write_todos tool per instance (rather than reusing the
        # module-level one) so the configured tool_description is applied.
        @tool(description=self.tool_description)
        def write_todos(
            todos: list[Todo], tool_call_id: Annotated[str, InjectedToolCallId]
        ) -> Command:
            """Create and manage a structured task list for your current work session."""
            return Command(
                update={
                    "todos": todos,
                    "messages": [
                        ToolMessage(f"Updated todo list to {todos}", tool_call_id=tool_call_id)
                    ],
                }
            )

        self.tools = [write_todos]

    def modify_model_request(  # type: ignore[override]
        self,
        request: ModelRequest,
        state: PlanningState,  # noqa: ARG002
    ) -> ModelRequest:
        """Append the todo-usage instructions to the request's system prompt.

        Preserves any existing system prompt; when none is set, the middleware's
        prompt is used on its own.
        """
        request.system_prompt = (
            request.system_prompt + "\n\n" + self.system_prompt
            if request.system_prompt
            else self.system_prompt
        )
        return request

View File

@@ -1,14 +1,9 @@
import pytest
import warnings
from types import ModuleType
from typing import Any
from unittest.mock import patch
from types import ModuleType
from syrupy.assertion import SnapshotAssertion
import warnings
from langgraph.runtime import Runtime
from typing_extensions import Annotated
from pydantic import BaseModel, Field
import pytest
from langchain_core.language_models import BaseChatModel
from langchain_core.language_models.chat_models import BaseChatModel
from langchain_core.messages import (
@@ -18,31 +13,42 @@ from langchain_core.messages import (
ToolCall,
ToolMessage,
)
from langchain_core.tools import tool, InjectedToolCallId
from langchain_core.outputs import ChatGeneration, ChatResult
from langchain_core.tools import InjectedToolCallId, tool
from langgraph.checkpoint.base import BaseCheckpointSaver
from langgraph.checkpoint.memory import InMemorySaver
from langgraph.constants import END
from langgraph.graph.message import REMOVE_ALL_MESSAGES
from langgraph.runtime import Runtime
from langgraph.types import Command
from pydantic import BaseModel, Field
from syrupy.assertion import SnapshotAssertion
from typing_extensions import Annotated
from langchain.agents.middleware_agent import create_agent
from langchain.tools import InjectedState
from langchain.agents.middleware.human_in_the_loop import (
HumanInTheLoopMiddleware,
ActionRequest,
HumanInTheLoopMiddleware,
)
from langchain.agents.middleware.planning import (
PlanningMiddleware,
PlanningState,
WRITE_TODOS_SYSTEM_PROMPT,
write_todos,
WRITE_TODOS_TOOL_DESCRIPTION,
)
from langchain.agents.middleware.prompt_caching import AnthropicPromptCachingMiddleware
from langchain.agents.middleware.summarization import SummarizationMiddleware
from langchain.agents.middleware.types import (
AgentMiddleware,
ModelRequest,
AgentState,
ModelRequest,
OmitFromInput,
OmitFromOutput,
PrivateStateAttr,
)
from langgraph.checkpoint.base import BaseCheckpointSaver
from langgraph.checkpoint.memory import InMemorySaver
from langgraph.constants import END
from langgraph.graph.message import REMOVE_ALL_MESSAGES
from langgraph.types import Command
from langchain.agents.middleware_agent import create_agent
from langchain.agents.structured_output import ToolStrategy
from langchain.tools import InjectedState
from .messages import _AnyIdHumanMessage, _AnyIdToolMessage
from .model import FakeToolCallingModel
@@ -1105,14 +1111,9 @@ def test_summarization_middleware_summary_creation() -> None:
class MockModel(BaseChatModel):
def invoke(self, prompt):
from langchain_core.messages import AIMessage
return AIMessage(content="Generated summary")
def _generate(self, messages, **kwargs):
from langchain_core.outputs import ChatResult, ChatGeneration
from langchain_core.messages import AIMessage
return ChatResult(generations=[ChatGeneration(message=AIMessage(content="Summary"))])
@property
@@ -1136,9 +1137,6 @@ def test_summarization_middleware_summary_creation() -> None:
raise Exception("Model error")
def _generate(self, messages, **kwargs):
from langchain_core.outputs import ChatResult, ChatGeneration
from langchain_core.messages import AIMessage
return ChatResult(generations=[ChatGeneration(message=AIMessage(content="Summary"))])
@property
@@ -1155,14 +1153,9 @@ def test_summarization_middleware_full_workflow() -> None:
class MockModel(BaseChatModel):
def invoke(self, prompt):
from langchain_core.messages import AIMessage
return AIMessage(content="Generated summary")
def _generate(self, messages, **kwargs):
from langchain_core.outputs import ChatResult, ChatGeneration
from langchain_core.messages import AIMessage
return ChatResult(generations=[ChatGeneration(message=AIMessage(content="Summary"))])
@property
@@ -1423,3 +1416,248 @@ def test_jump_to_is_ephemeral() -> None:
agent = agent.compile()
result = agent.invoke({"messages": [HumanMessage("Hello")]})
assert "jump_to" not in result
# Tests for PlanningMiddleware
def test_planning_middleware_initialization() -> None:
    """A default PlanningMiddleware exposes the planning state schema and one tool."""
    mw = PlanningMiddleware()
    assert mw.state_schema == PlanningState
    tools = mw.tools
    assert len(tools) == 1
    assert tools[0].name == "write_todos"
@pytest.mark.parametrize(
    "original_prompt,expected_prompt_prefix",
    [
        ("Original prompt", "Original prompt\n\n## `write_todos`"),
        (None, "## `write_todos`"),
    ],
)
def test_planning_middleware_modify_model_request(original_prompt, expected_prompt_prefix) -> None:
    """modify_model_request appends the todo guidance, with or without a prior prompt."""
    mw = PlanningMiddleware()
    fake_model = FakeToolCallingModel()
    incoming = ModelRequest(
        model=fake_model,
        system_prompt=original_prompt,
        messages=[HumanMessage(content="Hello")],
        tool_choice=None,
        tools=[],
        response_format=None,
        model_settings={},
    )
    state: PlanningState = {"messages": [HumanMessage(content="Hello")]}
    outgoing = mw.modify_model_request(incoming, state)
    assert outgoing.system_prompt.startswith(expected_prompt_prefix)
@pytest.mark.parametrize(
    "todos,expected_message",
    [
        ([], "Updated todo list to []"),
        (
            [{"content": "Task 1", "status": "pending"}],
            "Updated todo list to [{'content': 'Task 1', 'status': 'pending'}]",
        ),
        (
            [
                {"content": "Task 1", "status": "pending"},
                {"content": "Task 2", "status": "in_progress"},
            ],
            "Updated todo list to [{'content': 'Task 1', 'status': 'pending'}, {'content': 'Task 2', 'status': 'in_progress'}]",
        ),
        (
            [
                {"content": "Task 1", "status": "pending"},
                {"content": "Task 2", "status": "in_progress"},
                {"content": "Task 3", "status": "completed"},
            ],
            "Updated todo list to [{'content': 'Task 1', 'status': 'pending'}, {'content': 'Task 2', 'status': 'in_progress'}, {'content': 'Task 3', 'status': 'completed'}]",
        ),
    ],
)
def test_planning_middleware_write_todos_tool_execution(todos, expected_message) -> None:
    """Test that the write_todos tool executes correctly."""
    # Simulate a model-issued tool call; the expected ToolMessage content is the
    # Python repr of the todos list, so the literals above must match key order.
    tool_call = {
        "args": {"todos": todos},
        "name": "write_todos",
        "type": "tool_call",
        "id": "test_call",
    }
    result = write_todos.invoke(tool_call)
    # The returned Command both stores the todos and appends the acknowledgement.
    assert result.update["todos"] == todos
    assert result.update["messages"][0].content == expected_message
@pytest.mark.parametrize(
    "invalid_todos",
    [
        # Status outside the allowed Literal values.
        [{"content": "Task 1", "status": "invalid_status"}],
        # Missing the required "content" key.
        [{"status": "pending"}],
    ],
)
def test_planning_middleware_write_todos_tool_validation_errors(invalid_todos) -> None:
    """Test that the write_todos tool rejects invalid input."""
    tool_call = {
        "args": {"todos": invalid_todos},
        "name": "write_todos",
        "type": "tool_call",
        "id": "test_call",
    }
    # NOTE(review): pytest.raises(Exception) is broad (ruff B017); narrow this to
    # the specific validation error type once it is pinned down.
    with pytest.raises(Exception):
        write_todos.invoke(tool_call)
def test_planning_middleware_agent_creation_with_middleware() -> None:
    """Test that an agent can be created with the planning middleware."""
    # Scripted model: walks one todo through pending -> in_progress -> completed,
    # then returns no tool calls so the agent terminates.
    model = FakeToolCallingModel(
        tool_calls=[
            [
                {
                    "args": {"todos": [{"content": "Task 1", "status": "pending"}]},
                    "name": "write_todos",
                    "type": "tool_call",
                    "id": "test_call",
                }
            ],
            [
                {
                    "args": {"todos": [{"content": "Task 1", "status": "in_progress"}]},
                    "name": "write_todos",
                    "type": "tool_call",
                    "id": "test_call",
                }
            ],
            [
                {
                    "args": {"todos": [{"content": "Task 1", "status": "completed"}]},
                    "name": "write_todos",
                    "type": "tool_call",
                    "id": "test_call",
                }
            ],
            [],
        ]
    )
    middleware = PlanningMiddleware()
    agent = create_agent(model=model, middleware=[middleware])
    agent = agent.compile()
    result = agent.invoke({"messages": [HumanMessage("Hello")]})
    # Only the final todo state survives: each write_todos call replaces the list.
    assert result["todos"] == [{"content": "Task 1", "status": "completed"}]
    # human message (1)
    # ai message (2) - initial todo
    # tool message (3)
    # ai message (4) - updated todo
    # tool message (5)
    # ai message (6) - complete todo
    # tool message (7)
    # ai message (8) - no tool calls
    assert len(result["messages"]) == 8
def test_planning_middleware_custom_system_prompt() -> None:
    """A custom system_prompt is appended after the request's existing prompt."""
    prompt_override = "Custom todo system prompt for testing"
    mw = PlanningMiddleware(system_prompt=prompt_override)
    fake_model = FakeToolCallingModel()
    incoming = ModelRequest(
        model=fake_model,
        system_prompt="Original prompt",
        messages=[HumanMessage(content="Hello")],
        tool_choice=None,
        tools=[],
        response_format=None,
        model_settings={},
    )
    state: PlanningState = {"messages": [HumanMessage(content="Hello")]}
    outgoing = mw.modify_model_request(incoming, state)
    assert outgoing.system_prompt == "Original prompt\n\n" + prompt_override
def test_planning_middleware_custom_tool_description() -> None:
    """A custom tool_description is applied to the generated write_todos tool."""
    description_override = "Custom tool description for testing"
    mw = PlanningMiddleware(tool_description=description_override)
    assert len(mw.tools) == 1
    # Named to avoid shadowing the imported `tool` decorator.
    todos_tool = mw.tools[0]
    assert todos_tool.description == description_override
def test_planning_middleware_custom_system_prompt_and_tool_description() -> None:
    """Both overrides can be supplied together and each takes effect independently."""
    prompt_override = "Custom system prompt"
    description_override = "Custom tool description"
    mw = PlanningMiddleware(
        system_prompt=prompt_override,
        tool_description=description_override,
    )

    # The generated tool carries the overridden description.
    assert len(mw.tools) == 1
    todos_tool = mw.tools[0]
    assert todos_tool.description == description_override

    # With no prior system prompt, the override is used as-is.
    incoming = ModelRequest(
        model=FakeToolCallingModel(),
        system_prompt=None,
        messages=[HumanMessage(content="Hello")],
        tool_choice=None,
        tools=[],
        response_format=None,
        model_settings={},
    )
    state: PlanningState = {"messages": [HumanMessage(content="Hello")]}
    outgoing = mw.modify_model_request(incoming, state)
    assert outgoing.system_prompt == prompt_override
def test_planning_middleware_default_prompts() -> None:
    """Without overrides, the middleware falls back to the module-level defaults."""
    mw = PlanningMiddleware()
    assert mw.system_prompt == WRITE_TODOS_SYSTEM_PROMPT
    assert mw.tool_description == WRITE_TODOS_TOOL_DESCRIPTION
    assert len(mw.tools) == 1
    assert mw.tools[0].description == WRITE_TODOS_TOOL_DESCRIPTION
def test_planning_middleware_custom_system_prompt_agent_execution() -> None:
    """Test that a custom system prompt is used during a full agent run.

    Renamed from ``test_planning_middleware_custom_system_prompt``: the original
    name duplicated an earlier test in this module (ruff F811), so the earlier
    test was shadowed and never collected by pytest.
    """
    middleware = PlanningMiddleware(system_prompt="call the write_todos tool")
    # Scripted model: one write_todos call, then no tool calls to terminate.
    model = FakeToolCallingModel(
        tool_calls=[
            [
                {
                    "args": {"todos": [{"content": "Custom task", "status": "pending"}]},
                    "name": "write_todos",
                    "type": "tool_call",
                    "id": "test_call",
                }
            ],
            [],
        ]
    )
    agent = create_agent(model=model, middleware=[middleware])
    agent = agent.compile()
    result = agent.invoke({"messages": [HumanMessage("Hello")]})
    assert result["todos"] == [{"content": "Custom task", "status": "pending"}]
    # assert custom system prompt is in the first AI message
    assert "call the write_todos tool" in result["messages"][1].content