From eed0f6c2894d680a1e997828adcedbb604a8dcdc Mon Sep 17 00:00:00 2001 From: Sydney Runkle <54324534+sydney-runkle@users.noreply.github.com> Date: Mon, 29 Sep 2025 19:23:26 -0700 Subject: [PATCH] feat(langchain): todo middleware (#33152) Porting the [planning middleware](https://github.com/langchain-ai/deepagents/blob/39c0138d0fdb2f8dc6a02606931cac2ffb777aff/src/deepagents/middleware.py#L21) over from deepagents. Also adding the ability to configure: * System prompt * Tool description ```py from langchain.agents.middleware.planning import PlanningMiddleware from langchain.agents import create_agent agent = create_agent("openai:gpt-4o", middleware=[PlanningMiddleware()]) result = await agent.invoke({"messages": [HumanMessage("Help me refactor my codebase")]}) print(result["todos"]) # Array of todo items with status tracking ``` --- .../langchain/agents/middleware/__init__.py | 2 + .../langchain/agents/middleware/planning.py | 197 ++++++++++++ .../agents/test_middleware_agent.py | 302 ++++++++++++++++-- 3 files changed, 469 insertions(+), 32 deletions(-) create mode 100644 libs/langchain_v1/langchain/agents/middleware/planning.py diff --git a/libs/langchain_v1/langchain/agents/middleware/__init__.py b/libs/langchain_v1/langchain/agents/middleware/__init__.py index 014989516ae..64d06e85445 100644 --- a/libs/langchain_v1/langchain/agents/middleware/__init__.py +++ b/libs/langchain_v1/langchain/agents/middleware/__init__.py @@ -1,6 +1,7 @@ """Middleware plugins for agents.""" from .human_in_the_loop import HumanInTheLoopMiddleware +from .planning import PlanningMiddleware from .prompt_caching import AnthropicPromptCachingMiddleware from .summarization import SummarizationMiddleware from .types import AgentMiddleware, AgentState, ModelRequest @@ -12,5 +13,6 @@ __all__ = [ "AnthropicPromptCachingMiddleware", "HumanInTheLoopMiddleware", "ModelRequest", + "PlanningMiddleware", "SummarizationMiddleware", ] diff --git 
a/libs/langchain_v1/langchain/agents/middleware/planning.py b/libs/langchain_v1/langchain/agents/middleware/planning.py new file mode 100644 index 00000000000..5fb451dcf55 --- /dev/null +++ b/libs/langchain_v1/langchain/agents/middleware/planning.py @@ -0,0 +1,197 @@ +"""Planning and task management middleware for agents.""" +# ruff: noqa: E501 + +from __future__ import annotations + +from typing import Annotated, Literal + +from langchain_core.messages import ToolMessage +from langchain_core.tools import tool +from langgraph.types import Command +from typing_extensions import NotRequired, TypedDict + +from langchain.agents.middleware.types import AgentMiddleware, AgentState, ModelRequest +from langchain.tools import InjectedToolCallId + + +class Todo(TypedDict): + """A single todo item with content and status.""" + + content: str + """The content/description of the todo item.""" + + status: Literal["pending", "in_progress", "completed"] + """The current status of the todo item.""" + + +class PlanningState(AgentState): + """State schema for the todo middleware.""" + + todos: NotRequired[list[Todo]] + """List of todo items for tracking task progress.""" + + +WRITE_TODOS_TOOL_DESCRIPTION = """Use this tool to create and manage a structured task list for your current work session. This helps you track progress, organize complex tasks, and demonstrate thoroughness to the user. + +Only use this tool if you think it will be helpful in staying organized. If the user's request is trivial and takes less than 3 steps, it is better to NOT use this tool and just do the task directly. + +## When to Use This Tool +Use this tool in these scenarios: + +1. Complex multi-step tasks - When a task requires 3 or more distinct steps or actions +2. Non-trivial and complex tasks - Tasks that require careful planning or multiple operations +3. User explicitly requests todo list - When the user directly asks you to use the todo list +4. 
User provides multiple tasks - When users provide a list of things to be done (numbered or comma-separated) +5. The plan may need future revisions or updates based on results from the first few steps + +## How to Use This Tool +1. When you start working on a task - Mark it as in_progress BEFORE beginning work. +2. After completing a task - Mark it as completed and add any new follow-up tasks discovered during implementation. +3. You can also update future tasks, such as deleting them if they are no longer necessary, or adding new tasks that are necessary. Don't change previously completed tasks. +4. You can make several updates to the todo list at once. For example, when you complete a task, you can mark the next task you need to start as in_progress. + +## When NOT to Use This Tool +It is important to skip using this tool when: +1. There is only a single, straightforward task +2. The task is trivial and tracking it provides no benefit +3. The task can be completed in less than 3 trivial steps +4. The task is purely conversational or informational + +## Task States and Management + +1. **Task States**: Use these states to track progress: + - pending: Task not yet started + - in_progress: Currently working on (you can have multiple tasks in_progress at a time if they are not related to each other and can be run in parallel) + - completed: Task finished successfully + +2. **Task Management**: + - Update task status in real-time as you work + - Mark tasks complete IMMEDIATELY after finishing (don't batch completions) + - Complete current tasks before starting new ones + - Remove tasks that are no longer relevant from the list entirely + - IMPORTANT: When you write this todo list, you should mark your first task (or tasks) as in_progress immediately!. + - IMPORTANT: Unless all tasks are completed, you should always have at least one task in_progress to show the user that you are working on something. + +3. 
**Task Completion Requirements**: + - ONLY mark a task as completed when you have FULLY accomplished it + - If you encounter errors, blockers, or cannot finish, keep the task as in_progress + - When blocked, create a new task describing what needs to be resolved + - Never mark a task as completed if: + - There are unresolved issues or errors + - Work is partial or incomplete + - You encountered blockers that prevent completion + - You couldn't find necessary resources or dependencies + - Quality standards haven't been met + +4. **Task Breakdown**: + - Create specific, actionable items + - Break complex tasks into smaller, manageable steps + - Use clear, descriptive task names + +Being proactive with task management demonstrates attentiveness and ensures you complete all requirements successfully +Remember: If you only need to make a few tool calls to complete a task, and it is clear what you need to do, it is better to just do the task directly and NOT call this tool at all.""" + +WRITE_TODOS_SYSTEM_PROMPT = """## `write_todos` + +You have access to the `write_todos` tool to help you manage and plan complex objectives. +Use this tool for complex objectives to ensure that you are tracking each necessary step and giving the user visibility into your progress. +This tool is very helpful for planning complex objectives, and for breaking down these larger complex objectives into smaller steps. + +It is critical that you mark todos as completed as soon as you are done with a step. Do not batch up multiple steps before marking them as completed. +For simple objectives that only require a few steps, it is better to just complete the objective directly and NOT use this tool. +Writing todos takes time and tokens, use it when it is helpful for managing complex many-step problems! But not for simple few-step requests. + +## Important To-Do List Usage Notes to Remember +- The `write_todos` tool should never be called multiple times in parallel. 
+- Don't be afraid to revise the To-Do list as you go. New information may reveal new tasks that need to be done, or old tasks that are irrelevant.""" + + +@tool(description=WRITE_TODOS_TOOL_DESCRIPTION) +def write_todos(todos: list[Todo], tool_call_id: Annotated[str, InjectedToolCallId]) -> Command: + """Create and manage a structured task list for your current work session.""" + return Command( + update={ + "todos": todos, + "messages": [ToolMessage(f"Updated todo list to {todos}", tool_call_id=tool_call_id)], + } + ) + + +class PlanningMiddleware(AgentMiddleware): + """Middleware that provides todo list management capabilities to agents. + + This middleware adds a `write_todos` tool that allows agents to create and manage + structured task lists for complex multi-step operations. It's designed to help + agents track progress, organize complex tasks, and provide users with visibility + into task completion status. + + The middleware automatically injects system prompts that guide the agent on when + and how to use the todo functionality effectively. + + Example: + ```python + from langchain.agents.middleware.planning import PlanningMiddleware + from langchain.agents import create_agent + + agent = create_agent("openai:gpt-4o", middleware=[PlanningMiddleware()]) + + # Agent now has access to write_todos tool and todo state tracking + result = await agent.ainvoke({"messages": [HumanMessage("Help me refactor my codebase")]}) + + print(result["todos"]) # Array of todo items with status tracking + ``` + + Args: + system_prompt: Custom system prompt to guide the agent on using the todo tool. + If not provided, uses the default ``WRITE_TODOS_SYSTEM_PROMPT``. + tool_description: Custom description for the write_todos tool. + If not provided, uses the default ``WRITE_TODOS_TOOL_DESCRIPTION``. 
+ """ + + state_schema = PlanningState + + def __init__( + self, + *, + system_prompt: str = WRITE_TODOS_SYSTEM_PROMPT, + tool_description: str = WRITE_TODOS_TOOL_DESCRIPTION, + ) -> None: + """Initialize the PlanningMiddleware with optional custom prompts. + + Args: + system_prompt: Custom system prompt to guide the agent on using the todo tool. + tool_description: Custom description for the write_todos tool. + """ + super().__init__() + self.system_prompt = system_prompt + self.tool_description = tool_description + + # Dynamically create the write_todos tool with the custom description + @tool(description=self.tool_description) + def write_todos( + todos: list[Todo], tool_call_id: Annotated[str, InjectedToolCallId] + ) -> Command: + """Create and manage a structured task list for your current work session.""" + return Command( + update={ + "todos": todos, + "messages": [ + ToolMessage(f"Updated todo list to {todos}", tool_call_id=tool_call_id) + ], + } + ) + + self.tools = [write_todos] + + def modify_model_request( # type: ignore[override] + self, + request: ModelRequest, + state: PlanningState, # noqa: ARG002 + ) -> ModelRequest: + """Update the system prompt to include the todo system prompt.""" + request.system_prompt = ( + request.system_prompt + "\n\n" + self.system_prompt + if request.system_prompt + else self.system_prompt + ) + return request diff --git a/libs/langchain_v1/tests/unit_tests/agents/test_middleware_agent.py b/libs/langchain_v1/tests/unit_tests/agents/test_middleware_agent.py index 519c99e9a0d..62d4addccb2 100644 --- a/libs/langchain_v1/tests/unit_tests/agents/test_middleware_agent.py +++ b/libs/langchain_v1/tests/unit_tests/agents/test_middleware_agent.py @@ -1,14 +1,9 @@ -import pytest +import warnings +from types import ModuleType from typing import Any from unittest.mock import patch -from types import ModuleType -from syrupy.assertion import SnapshotAssertion - -import warnings -from langgraph.runtime import Runtime -from 
typing_extensions import Annotated -from pydantic import BaseModel, Field +import pytest from langchain_core.language_models import BaseChatModel from langchain_core.language_models.chat_models import BaseChatModel from langchain_core.messages import ( @@ -18,31 +13,42 @@ from langchain_core.messages import ( ToolCall, ToolMessage, ) -from langchain_core.tools import tool, InjectedToolCallId +from langchain_core.outputs import ChatGeneration, ChatResult +from langchain_core.tools import InjectedToolCallId, tool +from langgraph.checkpoint.base import BaseCheckpointSaver +from langgraph.checkpoint.memory import InMemorySaver +from langgraph.constants import END +from langgraph.graph.message import REMOVE_ALL_MESSAGES +from langgraph.runtime import Runtime +from langgraph.types import Command +from pydantic import BaseModel, Field +from syrupy.assertion import SnapshotAssertion +from typing_extensions import Annotated -from langchain.agents.middleware_agent import create_agent -from langchain.tools import InjectedState from langchain.agents.middleware.human_in_the_loop import ( - HumanInTheLoopMiddleware, ActionRequest, + HumanInTheLoopMiddleware, +) +from langchain.agents.middleware.planning import ( + PlanningMiddleware, + PlanningState, + WRITE_TODOS_SYSTEM_PROMPT, + write_todos, + WRITE_TODOS_TOOL_DESCRIPTION, ) from langchain.agents.middleware.prompt_caching import AnthropicPromptCachingMiddleware from langchain.agents.middleware.summarization import SummarizationMiddleware from langchain.agents.middleware.types import ( AgentMiddleware, - ModelRequest, AgentState, + ModelRequest, OmitFromInput, OmitFromOutput, PrivateStateAttr, ) - -from langgraph.checkpoint.base import BaseCheckpointSaver -from langgraph.checkpoint.memory import InMemorySaver -from langgraph.constants import END -from langgraph.graph.message import REMOVE_ALL_MESSAGES -from langgraph.types import Command +from langchain.agents.middleware_agent import create_agent from 
langchain.agents.structured_output import ToolStrategy +from langchain.tools import InjectedState from .messages import _AnyIdHumanMessage, _AnyIdToolMessage from .model import FakeToolCallingModel @@ -1105,14 +1111,9 @@ def test_summarization_middleware_summary_creation() -> None: class MockModel(BaseChatModel): def invoke(self, prompt): - from langchain_core.messages import AIMessage - return AIMessage(content="Generated summary") def _generate(self, messages, **kwargs): - from langchain_core.outputs import ChatResult, ChatGeneration - from langchain_core.messages import AIMessage - return ChatResult(generations=[ChatGeneration(message=AIMessage(content="Summary"))]) @property @@ -1136,9 +1137,6 @@ def test_summarization_middleware_summary_creation() -> None: raise Exception("Model error") def _generate(self, messages, **kwargs): - from langchain_core.outputs import ChatResult, ChatGeneration - from langchain_core.messages import AIMessage - return ChatResult(generations=[ChatGeneration(message=AIMessage(content="Summary"))]) @property @@ -1155,14 +1153,9 @@ def test_summarization_middleware_full_workflow() -> None: class MockModel(BaseChatModel): def invoke(self, prompt): - from langchain_core.messages import AIMessage - return AIMessage(content="Generated summary") def _generate(self, messages, **kwargs): - from langchain_core.outputs import ChatResult, ChatGeneration - from langchain_core.messages import AIMessage - return ChatResult(generations=[ChatGeneration(message=AIMessage(content="Summary"))]) @property @@ -1423,3 +1416,248 @@ def test_jump_to_is_ephemeral() -> None: agent = agent.compile() result = agent.invoke({"messages": [HumanMessage("Hello")]}) assert "jump_to" not in result + + +# Tests for PlanningMiddleware +def test_planning_middleware_initialization() -> None: + """Test that PlanningMiddleware initializes correctly.""" + middleware = PlanningMiddleware() + assert middleware.state_schema == PlanningState + assert len(middleware.tools) == 1 + 
assert middleware.tools[0].name == "write_todos" + + +@pytest.mark.parametrize( + "original_prompt,expected_prompt_prefix", + [ + ("Original prompt", "Original prompt\n\n## `write_todos`"), + (None, "## `write_todos`"), + ], +) +def test_planning_middleware_modify_model_request(original_prompt, expected_prompt_prefix) -> None: + """Test that modify_model_request handles system prompts correctly.""" + middleware = PlanningMiddleware() + model = FakeToolCallingModel() + + request = ModelRequest( + model=model, + system_prompt=original_prompt, + messages=[HumanMessage(content="Hello")], + tool_choice=None, + tools=[], + response_format=None, + model_settings={}, + ) + + state: PlanningState = {"messages": [HumanMessage(content="Hello")]} + modified_request = middleware.modify_model_request(request, state) + assert modified_request.system_prompt.startswith(expected_prompt_prefix) + + +@pytest.mark.parametrize( + "todos,expected_message", + [ + ([], "Updated todo list to []"), + ( + [{"content": "Task 1", "status": "pending"}], + "Updated todo list to [{'content': 'Task 1', 'status': 'pending'}]", + ), + ( + [ + {"content": "Task 1", "status": "pending"}, + {"content": "Task 2", "status": "in_progress"}, + ], + "Updated todo list to [{'content': 'Task 1', 'status': 'pending'}, {'content': 'Task 2', 'status': 'in_progress'}]", + ), + ( + [ + {"content": "Task 1", "status": "pending"}, + {"content": "Task 2", "status": "in_progress"}, + {"content": "Task 3", "status": "completed"}, + ], + "Updated todo list to [{'content': 'Task 1', 'status': 'pending'}, {'content': 'Task 2', 'status': 'in_progress'}, {'content': 'Task 3', 'status': 'completed'}]", + ), + ], +) +def test_planning_middleware_write_todos_tool_execution(todos, expected_message) -> None: + """Test that the write_todos tool executes correctly.""" + tool_call = { + "args": {"todos": todos}, + "name": "write_todos", + "type": "tool_call", + "id": "test_call", + } + result = write_todos.invoke(tool_call) + assert 
result.update["todos"] == todos + assert result.update["messages"][0].content == expected_message + + +@pytest.mark.parametrize( + "invalid_todos", + [ + [{"content": "Task 1", "status": "invalid_status"}], + [{"status": "pending"}], + ], +) +def test_planning_middleware_write_todos_tool_validation_errors(invalid_todos) -> None: + """Test that the write_todos tool rejects invalid input.""" + tool_call = { + "args": {"todos": invalid_todos}, + "name": "write_todos", + "type": "tool_call", + "id": "test_call", + } + with pytest.raises(Exception): + write_todos.invoke(tool_call) + + +def test_planning_middleware_agent_creation_with_middleware() -> None: + """Test that an agent can be created with the planning middleware.""" + model = FakeToolCallingModel( + tool_calls=[ + [ + { + "args": {"todos": [{"content": "Task 1", "status": "pending"}]}, + "name": "write_todos", + "type": "tool_call", + "id": "test_call", + } + ], + [ + { + "args": {"todos": [{"content": "Task 1", "status": "in_progress"}]}, + "name": "write_todos", + "type": "tool_call", + "id": "test_call", + } + ], + [ + { + "args": {"todos": [{"content": "Task 1", "status": "completed"}]}, + "name": "write_todos", + "type": "tool_call", + "id": "test_call", + } + ], + [], + ] + ) + middleware = PlanningMiddleware() + agent = create_agent(model=model, middleware=[middleware]) + agent = agent.compile() + + result = agent.invoke({"messages": [HumanMessage("Hello")]}) + assert result["todos"] == [{"content": "Task 1", "status": "completed"}] + + # human message (1) + # ai message (2) - initial todo + # tool message (3) + # ai message (4) - updated todo + # tool message (5) + # ai message (6) - complete todo + # tool message (7) + # ai message (8) - no tool calls + assert len(result["messages"]) == 8 + + +def test_planning_middleware_custom_system_prompt() -> None: + """Test that PlanningMiddleware can be initialized with custom system prompt.""" + custom_system_prompt = "Custom todo system prompt for testing" + 
middleware = PlanningMiddleware(system_prompt=custom_system_prompt) + model = FakeToolCallingModel() + + request = ModelRequest( + model=model, + system_prompt="Original prompt", + messages=[HumanMessage(content="Hello")], + tool_choice=None, + tools=[], + response_format=None, + model_settings={}, + ) + + state: PlanningState = {"messages": [HumanMessage(content="Hello")]} + modified_request = middleware.modify_model_request(request, state) + assert modified_request.system_prompt == f"Original prompt\n\n{custom_system_prompt}" + + +def test_planning_middleware_custom_tool_description() -> None: + """Test that PlanningMiddleware can be initialized with custom tool description.""" + custom_tool_description = "Custom tool description for testing" + middleware = PlanningMiddleware(tool_description=custom_tool_description) + + assert len(middleware.tools) == 1 + tool = middleware.tools[0] + assert tool.description == custom_tool_description + + +def test_planning_middleware_custom_system_prompt_and_tool_description() -> None: + """Test that PlanningMiddleware can be initialized with both custom prompts.""" + custom_system_prompt = "Custom system prompt" + custom_tool_description = "Custom tool description" + middleware = PlanningMiddleware( + system_prompt=custom_system_prompt, + tool_description=custom_tool_description, + ) + + # Verify system prompt + model = FakeToolCallingModel() + request = ModelRequest( + model=model, + system_prompt=None, + messages=[HumanMessage(content="Hello")], + tool_choice=None, + tools=[], + response_format=None, + model_settings={}, + ) + + state: PlanningState = {"messages": [HumanMessage(content="Hello")]} + modified_request = middleware.modify_model_request(request, state) + assert modified_request.system_prompt == custom_system_prompt + + # Verify tool description + assert len(middleware.tools) == 1 + tool = middleware.tools[0] + assert tool.description == custom_tool_description + + +def test_planning_middleware_default_prompts() -> 
None: + """Test that PlanningMiddleware uses default prompts when none provided.""" + middleware = PlanningMiddleware() + + # Verify default system prompt + assert middleware.system_prompt == WRITE_TODOS_SYSTEM_PROMPT + + # Verify default tool description + assert middleware.tool_description == WRITE_TODOS_TOOL_DESCRIPTION + assert len(middleware.tools) == 1 + tool = middleware.tools[0] + assert tool.description == WRITE_TODOS_TOOL_DESCRIPTION + + +def test_planning_middleware_custom_system_prompt_agent_execution() -> None: + """Test that a custom system prompt is injected during agent execution.""" + middleware = PlanningMiddleware(system_prompt="call the write_todos tool") + + model = FakeToolCallingModel( + tool_calls=[ + [ + { + "args": {"todos": [{"content": "Custom task", "status": "pending"}]}, + "name": "write_todos", + "type": "tool_call", + "id": "test_call", + } + ], + [], + ] + ) + + agent = create_agent(model=model, middleware=[middleware]) + agent = agent.compile() + + result = agent.invoke({"messages": [HumanMessage("Hello")]}) + assert result["todos"] == [{"content": "Custom task", "status": "pending"}] + # assert custom system prompt is in the first AI message + assert "call the write_todos tool" in result["messages"][1].content