mirror of
https://github.com/hwchase17/langchain.git
synced 2026-03-18 11:07:36 +00:00
feat(langchain): todo middleware (#33152)
Porting the [planning
middleware](39c0138d0f/src/deepagents/middleware.py (L21))
over from deepagents.
Also adding the ability to configure:
* System prompt
* Tool description
```py
from langchain.agents.middleware.planning import PlanningMiddleware
from langchain.agents import create_agent
agent = create_agent("openai:gpt-4o", middleware=[PlanningMiddleware()])
result = await agent.ainvoke({"messages": [HumanMessage("Help me refactor my codebase")]})
print(result["todos"]) # Array of todo items with status tracking
```
This commit is contained in:
@@ -1,6 +1,7 @@
|
||||
"""Middleware plugins for agents."""
|
||||
|
||||
from .human_in_the_loop import HumanInTheLoopMiddleware
|
||||
from .planning import PlanningMiddleware
|
||||
from .prompt_caching import AnthropicPromptCachingMiddleware
|
||||
from .summarization import SummarizationMiddleware
|
||||
from .types import AgentMiddleware, AgentState, ModelRequest
|
||||
@@ -12,5 +13,6 @@ __all__ = [
|
||||
"AnthropicPromptCachingMiddleware",
|
||||
"HumanInTheLoopMiddleware",
|
||||
"ModelRequest",
|
||||
"PlanningMiddleware",
|
||||
"SummarizationMiddleware",
|
||||
]
|
||||
|
||||
197
libs/langchain_v1/langchain/agents/middleware/planning.py
Normal file
197
libs/langchain_v1/langchain/agents/middleware/planning.py
Normal file
@@ -0,0 +1,197 @@
|
||||
"""Planning and task management middleware for agents."""
|
||||
# ruff: noqa: E501
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Annotated, Literal
|
||||
|
||||
from langchain_core.messages import ToolMessage
|
||||
from langchain_core.tools import tool
|
||||
from langgraph.types import Command
|
||||
from typing_extensions import NotRequired, TypedDict
|
||||
|
||||
from langchain.agents.middleware.types import AgentMiddleware, AgentState, ModelRequest
|
||||
from langchain.tools import InjectedToolCallId
|
||||
|
||||
|
||||
class Todo(TypedDict):
    """A single todo item with content and status."""

    # Human-readable description of the task.
    content: str

    # Lifecycle state of the task; advances pending -> in_progress -> completed.
    status: Literal["pending", "in_progress", "completed"]
|
||||
|
||||
|
||||
class PlanningState(AgentState):
    """State schema for the todo middleware."""

    # Current task list; absent from state until the agent first writes todos.
    todos: NotRequired[list[Todo]]
|
||||
|
||||
|
||||
WRITE_TODOS_TOOL_DESCRIPTION = """Use this tool to create and manage a structured task list for your current work session. This helps you track progress, organize complex tasks, and demonstrate thoroughness to the user.
|
||||
|
||||
Only use this tool if you think it will be helpful in staying organized. If the user's request is trivial and takes less than 3 steps, it is better to NOT use this tool and just do the task directly.
|
||||
|
||||
## When to Use This Tool
|
||||
Use this tool in these scenarios:
|
||||
|
||||
1. Complex multi-step tasks - When a task requires 3 or more distinct steps or actions
|
||||
2. Non-trivial and complex tasks - Tasks that require careful planning or multiple operations
|
||||
3. User explicitly requests todo list - When the user directly asks you to use the todo list
|
||||
4. User provides multiple tasks - When users provide a list of things to be done (numbered or comma-separated)
|
||||
5. The plan may need future revisions or updates based on results from the first few steps
|
||||
|
||||
## How to Use This Tool
|
||||
1. When you start working on a task - Mark it as in_progress BEFORE beginning work.
|
||||
2. After completing a task - Mark it as completed and add any new follow-up tasks discovered during implementation.
|
||||
3. You can also update future tasks, such as deleting them if they are no longer necessary, or adding new tasks that are necessary. Don't change previously completed tasks.
|
||||
4. You can make several updates to the todo list at once. For example, when you complete a task, you can mark the next task you need to start as in_progress.
|
||||
|
||||
## When NOT to Use This Tool
|
||||
It is important to skip using this tool when:
|
||||
1. There is only a single, straightforward task
|
||||
2. The task is trivial and tracking it provides no benefit
|
||||
3. The task can be completed in less than 3 trivial steps
|
||||
4. The task is purely conversational or informational
|
||||
|
||||
## Task States and Management
|
||||
|
||||
1. **Task States**: Use these states to track progress:
|
||||
- pending: Task not yet started
|
||||
- in_progress: Currently working on (you can have multiple tasks in_progress at a time if they are not related to each other and can be run in parallel)
|
||||
- completed: Task finished successfully
|
||||
|
||||
2. **Task Management**:
|
||||
- Update task status in real-time as you work
|
||||
- Mark tasks complete IMMEDIATELY after finishing (don't batch completions)
|
||||
- Complete current tasks before starting new ones
|
||||
- Remove tasks that are no longer relevant from the list entirely
|
||||
- IMPORTANT: When you write this todo list, you should mark your first task (or tasks) as in_progress immediately!.
|
||||
- IMPORTANT: Unless all tasks are completed, you should always have at least one task in_progress to show the user that you are working on something.
|
||||
|
||||
3. **Task Completion Requirements**:
|
||||
- ONLY mark a task as completed when you have FULLY accomplished it
|
||||
- If you encounter errors, blockers, or cannot finish, keep the task as in_progress
|
||||
- When blocked, create a new task describing what needs to be resolved
|
||||
- Never mark a task as completed if:
|
||||
- There are unresolved issues or errors
|
||||
- Work is partial or incomplete
|
||||
- You encountered blockers that prevent completion
|
||||
- You couldn't find necessary resources or dependencies
|
||||
- Quality standards haven't been met
|
||||
|
||||
4. **Task Breakdown**:
|
||||
- Create specific, actionable items
|
||||
- Break complex tasks into smaller, manageable steps
|
||||
- Use clear, descriptive task names
|
||||
|
||||
Being proactive with task management demonstrates attentiveness and ensures you complete all requirements successfully
|
||||
Remember: If you only need to make a few tool calls to complete a task, and it is clear what you need to do, it is better to just do the task directly and NOT call this tool at all."""
|
||||
|
||||
WRITE_TODOS_SYSTEM_PROMPT = """## `write_todos`
|
||||
|
||||
You have access to the `write_todos` tool to help you manage and plan complex objectives.
|
||||
Use this tool for complex objectives to ensure that you are tracking each necessary step and giving the user visibility into your progress.
|
||||
This tool is very helpful for planning complex objectives, and for breaking down these larger complex objectives into smaller steps.
|
||||
|
||||
It is critical that you mark todos as completed as soon as you are done with a step. Do not batch up multiple steps before marking them as completed.
|
||||
For simple objectives that only require a few steps, it is better to just complete the objective directly and NOT use this tool.
|
||||
Writing todos takes time and tokens, use it when it is helpful for managing complex many-step problems! But not for simple few-step requests.
|
||||
|
||||
## Important To-Do List Usage Notes to Remember
|
||||
- The `write_todos` tool should never be called multiple times in parallel.
|
||||
- Don't be afraid to revise the To-Do list as you go. New information may reveal new tasks that need to be done, or old tasks that are irrelevant."""
|
||||
|
||||
|
||||
@tool(description=WRITE_TODOS_TOOL_DESCRIPTION)
def write_todos(todos: list[Todo], tool_call_id: Annotated[str, InjectedToolCallId]) -> Command:
    """Create and manage a structured task list for your current work session."""
    # Store the new list under `todos` in agent state and echo a confirmation
    # back to the model as the tool-call result.
    confirmation = ToolMessage(f"Updated todo list to {todos}", tool_call_id=tool_call_id)
    return Command(update={"todos": todos, "messages": [confirmation]})
|
||||
|
||||
|
||||
class PlanningMiddleware(AgentMiddleware):
    """Middleware that equips an agent with todo-list management.

    The middleware contributes two things to the agent:

    * a ``todos`` entry in the agent state (see ``PlanningState``) holding the
      current task list, and
    * a ``write_todos`` tool the model can call to create or revise that list.

    On every model request it also appends guidance (``system_prompt``) telling
    the model when and how to plan with todos.

    Example:
        ```python
        from langchain.agents.middleware.planning import PlanningMiddleware
        from langchain.agents import create_agent

        agent = create_agent("openai:gpt-4o", middleware=[PlanningMiddleware()])

        # Agent now has access to write_todos tool and todo state tracking
        result = await agent.ainvoke({"messages": [HumanMessage("Help me refactor my codebase")]})

        print(result["todos"])  # Array of todo items with status tracking
        ```

    Args:
        system_prompt: Custom system prompt to guide the agent on using the todo tool.
            If not provided, uses the default ``WRITE_TODOS_SYSTEM_PROMPT``.
        tool_description: Custom description for the write_todos tool.
            If not provided, uses the default ``WRITE_TODOS_TOOL_DESCRIPTION``.
    """

    state_schema = PlanningState

    def __init__(
        self,
        *,
        system_prompt: str = WRITE_TODOS_SYSTEM_PROMPT,
        tool_description: str = WRITE_TODOS_TOOL_DESCRIPTION,
    ) -> None:
        """Configure the middleware, optionally overriding its prompts.

        Args:
            system_prompt: Guidance appended to the model's system prompt.
            tool_description: Description attached to the ``write_todos`` tool.
        """
        super().__init__()
        self.system_prompt = system_prompt
        self.tool_description = tool_description

        # The tool is built per-instance (rather than reusing the module-level
        # `write_todos`) so a custom `tool_description` reaches the model.
        @tool(description=self.tool_description)
        def write_todos(
            todos: list[Todo], tool_call_id: Annotated[str, InjectedToolCallId]
        ) -> Command:
            """Create and manage a structured task list for your current work session."""
            confirmation = ToolMessage(
                f"Updated todo list to {todos}", tool_call_id=tool_call_id
            )
            return Command(update={"todos": todos, "messages": [confirmation]})

        self.tools = [write_todos]

    def modify_model_request(  # type: ignore[override]
        self,
        request: ModelRequest,
        state: PlanningState,  # noqa: ARG002
    ) -> ModelRequest:
        """Append the todo guidance to the request's system prompt."""
        if request.system_prompt:
            request.system_prompt = f"{request.system_prompt}\n\n{self.system_prompt}"
        else:
            request.system_prompt = self.system_prompt
        return request
|
||||
@@ -1,14 +1,9 @@
|
||||
import pytest
|
||||
import warnings
|
||||
from types import ModuleType
|
||||
from typing import Any
|
||||
from unittest.mock import patch
|
||||
from types import ModuleType
|
||||
|
||||
from syrupy.assertion import SnapshotAssertion
|
||||
|
||||
import warnings
|
||||
from langgraph.runtime import Runtime
|
||||
from typing_extensions import Annotated
|
||||
from pydantic import BaseModel, Field
|
||||
import pytest
|
||||
from langchain_core.language_models import BaseChatModel
|
||||
from langchain_core.language_models.chat_models import BaseChatModel
|
||||
from langchain_core.messages import (
|
||||
@@ -18,31 +13,42 @@ from langchain_core.messages import (
|
||||
ToolCall,
|
||||
ToolMessage,
|
||||
)
|
||||
from langchain_core.tools import tool, InjectedToolCallId
|
||||
from langchain_core.outputs import ChatGeneration, ChatResult
|
||||
from langchain_core.tools import InjectedToolCallId, tool
|
||||
from langgraph.checkpoint.base import BaseCheckpointSaver
|
||||
from langgraph.checkpoint.memory import InMemorySaver
|
||||
from langgraph.constants import END
|
||||
from langgraph.graph.message import REMOVE_ALL_MESSAGES
|
||||
from langgraph.runtime import Runtime
|
||||
from langgraph.types import Command
|
||||
from pydantic import BaseModel, Field
|
||||
from syrupy.assertion import SnapshotAssertion
|
||||
from typing_extensions import Annotated
|
||||
|
||||
from langchain.agents.middleware_agent import create_agent
|
||||
from langchain.tools import InjectedState
|
||||
from langchain.agents.middleware.human_in_the_loop import (
|
||||
HumanInTheLoopMiddleware,
|
||||
ActionRequest,
|
||||
HumanInTheLoopMiddleware,
|
||||
)
|
||||
from langchain.agents.middleware.planning import (
|
||||
PlanningMiddleware,
|
||||
PlanningState,
|
||||
WRITE_TODOS_SYSTEM_PROMPT,
|
||||
write_todos,
|
||||
WRITE_TODOS_TOOL_DESCRIPTION,
|
||||
)
|
||||
from langchain.agents.middleware.prompt_caching import AnthropicPromptCachingMiddleware
|
||||
from langchain.agents.middleware.summarization import SummarizationMiddleware
|
||||
from langchain.agents.middleware.types import (
|
||||
AgentMiddleware,
|
||||
ModelRequest,
|
||||
AgentState,
|
||||
ModelRequest,
|
||||
OmitFromInput,
|
||||
OmitFromOutput,
|
||||
PrivateStateAttr,
|
||||
)
|
||||
|
||||
from langgraph.checkpoint.base import BaseCheckpointSaver
|
||||
from langgraph.checkpoint.memory import InMemorySaver
|
||||
from langgraph.constants import END
|
||||
from langgraph.graph.message import REMOVE_ALL_MESSAGES
|
||||
from langgraph.types import Command
|
||||
from langchain.agents.middleware_agent import create_agent
|
||||
from langchain.agents.structured_output import ToolStrategy
|
||||
from langchain.tools import InjectedState
|
||||
|
||||
from .messages import _AnyIdHumanMessage, _AnyIdToolMessage
|
||||
from .model import FakeToolCallingModel
|
||||
@@ -1105,14 +1111,9 @@ def test_summarization_middleware_summary_creation() -> None:
|
||||
|
||||
class MockModel(BaseChatModel):
|
||||
def invoke(self, prompt):
|
||||
from langchain_core.messages import AIMessage
|
||||
|
||||
return AIMessage(content="Generated summary")
|
||||
|
||||
def _generate(self, messages, **kwargs):
|
||||
from langchain_core.outputs import ChatResult, ChatGeneration
|
||||
from langchain_core.messages import AIMessage
|
||||
|
||||
return ChatResult(generations=[ChatGeneration(message=AIMessage(content="Summary"))])
|
||||
|
||||
@property
|
||||
@@ -1136,9 +1137,6 @@ def test_summarization_middleware_summary_creation() -> None:
|
||||
raise Exception("Model error")
|
||||
|
||||
def _generate(self, messages, **kwargs):
|
||||
from langchain_core.outputs import ChatResult, ChatGeneration
|
||||
from langchain_core.messages import AIMessage
|
||||
|
||||
return ChatResult(generations=[ChatGeneration(message=AIMessage(content="Summary"))])
|
||||
|
||||
@property
|
||||
@@ -1155,14 +1153,9 @@ def test_summarization_middleware_full_workflow() -> None:
|
||||
|
||||
class MockModel(BaseChatModel):
|
||||
def invoke(self, prompt):
|
||||
from langchain_core.messages import AIMessage
|
||||
|
||||
return AIMessage(content="Generated summary")
|
||||
|
||||
def _generate(self, messages, **kwargs):
|
||||
from langchain_core.outputs import ChatResult, ChatGeneration
|
||||
from langchain_core.messages import AIMessage
|
||||
|
||||
return ChatResult(generations=[ChatGeneration(message=AIMessage(content="Summary"))])
|
||||
|
||||
@property
|
||||
@@ -1423,3 +1416,248 @@ def test_jump_to_is_ephemeral() -> None:
|
||||
agent = agent.compile()
|
||||
result = agent.invoke({"messages": [HumanMessage("Hello")]})
|
||||
assert "jump_to" not in result
|
||||
|
||||
|
||||
# Tests for PlanningMiddleware
|
||||
def test_planning_middleware_initialization() -> None:
    """PlanningMiddleware exposes the planning state schema and exactly one tool."""
    middleware = PlanningMiddleware()

    assert middleware.state_schema is PlanningState
    assert [t.name for t in middleware.tools] == ["write_todos"]
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
    "original_prompt,expected_prompt_prefix",
    [
        ("Original prompt", "Original prompt\n\n## `write_todos`"),
        (None, "## `write_todos`"),
    ],
)
def test_planning_middleware_modify_model_request(original_prompt, expected_prompt_prefix) -> None:
    """modify_model_request appends the todo guidance to any existing prompt."""
    middleware = PlanningMiddleware()

    request = ModelRequest(
        model=FakeToolCallingModel(),
        system_prompt=original_prompt,
        messages=[HumanMessage(content="Hello")],
        tool_choice=None,
        tools=[],
        response_format=None,
        model_settings={},
    )
    state: PlanningState = {"messages": [HumanMessage(content="Hello")]}

    updated = middleware.modify_model_request(request, state)
    assert updated.system_prompt.startswith(expected_prompt_prefix)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
    "todos,expected_message",
    [
        ([], "Updated todo list to []"),
        (
            [{"content": "Task 1", "status": "pending"}],
            "Updated todo list to [{'content': 'Task 1', 'status': 'pending'}]",
        ),
        (
            [
                {"content": "Task 1", "status": "pending"},
                {"content": "Task 2", "status": "in_progress"},
            ],
            "Updated todo list to [{'content': 'Task 1', 'status': 'pending'}, {'content': 'Task 2', 'status': 'in_progress'}]",
        ),
        (
            [
                {"content": "Task 1", "status": "pending"},
                {"content": "Task 2", "status": "in_progress"},
                {"content": "Task 3", "status": "completed"},
            ],
            "Updated todo list to [{'content': 'Task 1', 'status': 'pending'}, {'content': 'Task 2', 'status': 'in_progress'}, {'content': 'Task 3', 'status': 'completed'}]",
        ),
    ],
)
def test_planning_middleware_write_todos_tool_execution(todos, expected_message) -> None:
    """The write_todos tool stores the list in state and reports what it stored."""
    result = write_todos.invoke(
        {
            "name": "write_todos",
            "type": "tool_call",
            "id": "test_call",
            "args": {"todos": todos},
        }
    )

    assert result.update["todos"] == todos
    assert result.update["messages"][0].content == expected_message
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
    "invalid_todos",
    [
        [{"content": "Task 1", "status": "invalid_status"}],
        [{"status": "pending"}],
    ],
)
def test_planning_middleware_write_todos_tool_validation_errors(invalid_todos) -> None:
    """The write_todos tool rejects todos that don't match the Todo schema."""
    call = {
        "name": "write_todos",
        "type": "tool_call",
        "id": "test_call",
        "args": {"todos": invalid_todos},
    }

    # NOTE(review): broad `Exception` — narrow to the concrete validation error
    # type (likely pydantic's ValidationError) once confirmed.
    with pytest.raises(Exception):
        write_todos.invoke(call)
|
||||
|
||||
|
||||
def test_planning_middleware_agent_creation_with_middleware() -> None:
    """An agent wired with PlanningMiddleware tracks todo status across turns."""

    def _todo_call(status: str) -> dict:
        # One scripted write_todos tool call moving "Task 1" to `status`.
        return {
            "args": {"todos": [{"content": "Task 1", "status": status}]},
            "name": "write_todos",
            "type": "tool_call",
            "id": "test_call",
        }

    model = FakeToolCallingModel(
        tool_calls=[
            [_todo_call("pending")],
            [_todo_call("in_progress")],
            [_todo_call("completed")],
            [],
        ]
    )
    agent = create_agent(model=model, middleware=[PlanningMiddleware()])
    agent = agent.compile()

    result = agent.invoke({"messages": [HumanMessage("Hello")]})
    assert result["todos"] == [{"content": "Task 1", "status": "completed"}]

    # Conversation shape: 1 human message, then three (AI tool-call + tool
    # result) pairs for pending -> in_progress -> completed, then one final
    # AI message with no tool calls = 8 messages total.
    assert len(result["messages"]) == 8
|
||||
|
||||
|
||||
def test_planning_middleware_custom_system_prompt() -> None:
    """A custom system prompt replaces the default guidance text."""
    custom_system_prompt = "Custom todo system prompt for testing"
    middleware = PlanningMiddleware(system_prompt=custom_system_prompt)

    request = ModelRequest(
        model=FakeToolCallingModel(),
        system_prompt="Original prompt",
        messages=[HumanMessage(content="Hello")],
        tool_choice=None,
        tools=[],
        response_format=None,
        model_settings={},
    )
    state: PlanningState = {"messages": [HumanMessage(content="Hello")]}

    modified_request = middleware.modify_model_request(request, state)
    assert modified_request.system_prompt == f"Original prompt\n\n{custom_system_prompt}"
|
||||
|
||||
|
||||
def test_planning_middleware_custom_tool_description() -> None:
    """A custom description is attached to the generated write_todos tool."""
    custom_tool_description = "Custom tool description for testing"
    middleware = PlanningMiddleware(tool_description=custom_tool_description)

    # Exactly one tool is registered, and it carries the custom description.
    (todo_tool,) = middleware.tools
    assert todo_tool.description == custom_tool_description
|
||||
|
||||
|
||||
def test_planning_middleware_custom_system_prompt_and_tool_description() -> None:
    """Both prompts can be overridden together and act independently."""
    custom_system_prompt = "Custom system prompt"
    custom_tool_description = "Custom tool description"
    middleware = PlanningMiddleware(
        system_prompt=custom_system_prompt,
        tool_description=custom_tool_description,
    )

    # System prompt: with no existing prompt, the override is used verbatim.
    request = ModelRequest(
        model=FakeToolCallingModel(),
        system_prompt=None,
        messages=[HumanMessage(content="Hello")],
        tool_choice=None,
        tools=[],
        response_format=None,
        model_settings={},
    )
    state: PlanningState = {"messages": [HumanMessage(content="Hello")]}

    modified_request = middleware.modify_model_request(request, state)
    assert modified_request.system_prompt == custom_system_prompt

    # Tool description: the single registered tool carries the override.
    (todo_tool,) = middleware.tools
    assert todo_tool.description == custom_tool_description
|
||||
|
||||
|
||||
def test_planning_middleware_default_prompts() -> None:
    """Without overrides the middleware falls back to the module defaults."""
    middleware = PlanningMiddleware()

    assert middleware.system_prompt == WRITE_TODOS_SYSTEM_PROMPT
    assert middleware.tool_description == WRITE_TODOS_TOOL_DESCRIPTION

    (todo_tool,) = middleware.tools
    assert todo_tool.description == WRITE_TODOS_TOOL_DESCRIPTION
|
||||
|
||||
|
||||
def test_planning_middleware_custom_system_prompt_agent_invocation() -> None:
    """Test that a custom system prompt is used end-to-end in an agent run.

    Renamed: this function was previously also named
    ``test_planning_middleware_custom_system_prompt``, duplicating the earlier
    test of that name — the second definition shadowed the first, so pytest
    only collected one of the two tests.
    """
    middleware = PlanningMiddleware(system_prompt="call the write_todos tool")

    model = FakeToolCallingModel(
        tool_calls=[
            [
                {
                    "args": {"todos": [{"content": "Custom task", "status": "pending"}]},
                    "name": "write_todos",
                    "type": "tool_call",
                    "id": "test_call",
                }
            ],
            [],
        ]
    )

    agent = create_agent(model=model, middleware=[middleware])
    agent = agent.compile()

    result = agent.invoke({"messages": [HumanMessage("Hello")]})
    assert result["todos"] == [{"content": "Custom task", "status": "pending"}]
    # Assert the custom system prompt surfaces in the first AI message
    # (assumes FakeToolCallingModel echoes its prompt — confirm in .model).
    assert "call the write_todos tool" in result["messages"][1].content
|
||||
|
||||
Reference in New Issue
Block a user