chore(langchain): fix types in test_tools (#34592)

Co-authored-by: Mason Daugherty <mason@langchain.dev>
Co-authored-by: Mason Daugherty <github@mdrxy.com>
This commit is contained in:
Christophe Bornet
2026-01-10 00:05:28 +01:00
committed by GitHub
parent 4e40c2766a
commit 61fd703e5f

View File

@@ -1,14 +1,22 @@
"""Test Middleware handling of tools in agents."""
from collections.abc import Callable
from typing import Any
import pytest
from langchain_core.messages import AIMessage, HumanMessage, ToolMessage
from langchain_core.messages import HumanMessage, ToolMessage
from langchain_core.tools import tool
from langchain_core.tools.base import BaseTool
from langgraph.prebuilt.tool_node import ToolNode
from langchain.agents.factory import create_agent
from langchain.agents.middleware.types import AgentMiddleware, AgentState, ModelRequest
from langchain.agents.middleware.types import (
AgentMiddleware,
AgentState,
ModelCallResult,
ModelRequest,
ModelResponse,
)
from tests.unit_tests.agents.model import FakeToolCallingModel
@@ -30,8 +38,8 @@ def test_model_request_tools_are_base_tools() -> None:
def wrap_model_call(
self,
request: ModelRequest,
handler: Callable[[ModelRequest], AIMessage],
) -> AIMessage:
handler: Callable[[ModelRequest], ModelResponse],
) -> ModelCallResult:
captured_requests.append(request)
return handler(request)
@@ -51,7 +59,15 @@ def test_model_request_tools_are_base_tools() -> None:
request = captured_requests[0]
assert isinstance(request.tools, list)
assert len(request.tools) == 2
assert {t.name for t in request.tools} == {"search_tool", "calculator"}
tools = []
for t in request.tools:
assert isinstance(t, BaseTool)
tools.append(t.name)
assert set(tools) == {
"search_tool",
"calculator",
}
def test_middleware_can_modify_tools() -> None:
@@ -76,10 +92,14 @@ def test_middleware_can_modify_tools() -> None:
def wrap_model_call(
self,
request: ModelRequest,
handler: Callable[[ModelRequest], AIMessage],
) -> AIMessage:
handler: Callable[[ModelRequest], ModelResponse],
) -> ModelCallResult:
# Only allow tool_a and tool_b
filtered_tools = [t for t in request.tools if t.name in {"tool_a", "tool_b"}]
filtered_tools: list[BaseTool | dict[str, Any]] = []
for t in request.tools:
assert isinstance(t, BaseTool)
if t.name in {"tool_a", "tool_b"}:
filtered_tools.append(t)
return handler(request.override(tools=filtered_tools))
# Model will try to call tool_a
@@ -120,8 +140,8 @@ def test_unknown_tool_raises_error() -> None:
def wrap_model_call(
self,
request: ModelRequest,
handler: Callable[[ModelRequest], AIMessage],
) -> AIMessage:
handler: Callable[[ModelRequest], ModelResponse],
) -> ModelCallResult:
# Add an unknown tool
return handler(request.override(tools=[*request.tools, unknown_tool]))
@@ -149,7 +169,7 @@ def test_middleware_can_add_and_remove_tools() -> None:
"""Admin-only tool."""
return f"Admin: {command}"
class AdminState(AgentState):
class AdminState(AgentState[Any]):
is_admin: bool
class ConditionalToolMiddleware(AgentMiddleware[AdminState]):
@@ -158,11 +178,15 @@ def test_middleware_can_add_and_remove_tools() -> None:
def wrap_model_call(
self,
request: ModelRequest,
handler: Callable[[ModelRequest], AIMessage],
) -> AIMessage:
handler: Callable[[ModelRequest], ModelResponse],
) -> ModelCallResult:
# Remove admin_tool if not admin
if not request.state.get("is_admin", False):
filtered_tools = [t for t in request.tools if t.name != "admin_tool"]
filtered_tools: list[BaseTool | dict[str, Any]] = []
for t in request.tools:
assert isinstance(t, BaseTool)
if t.name != "admin_tool":
filtered_tools.append(t)
request = request.override(tools=filtered_tools)
return handler(request)
@@ -197,8 +221,8 @@ def test_empty_tools_list_is_valid() -> None:
def wrap_model_call(
self,
request: ModelRequest,
handler: Callable[[ModelRequest], AIMessage],
) -> AIMessage:
handler: Callable[[ModelRequest], ModelResponse],
) -> ModelCallResult:
# Remove all tools
request = request.override(tools=[])
return handler(request)
@@ -240,11 +264,17 @@ def test_tools_preserved_across_multiple_middleware() -> None:
def wrap_model_call(
self,
request: ModelRequest,
handler: Callable[[ModelRequest], AIMessage],
) -> AIMessage:
modification_order.append([t.name for t in request.tools])
# Remove tool_c
filtered_tools = [t for t in request.tools if t.name != "tool_c"]
handler: Callable[[ModelRequest], ModelResponse],
) -> ModelCallResult:
tools: list[str] = []
filtered_tools: list[BaseTool | dict[str, Any]] = []
for t in request.tools:
assert isinstance(t, BaseTool)
tools.append(t.name)
# Remove tool_c
if t.name != "tool_c":
filtered_tools.append(t)
modification_order.append(tools)
request = request.override(tools=filtered_tools)
return handler(request)
@@ -252,13 +282,19 @@ def test_tools_preserved_across_multiple_middleware() -> None:
def wrap_model_call(
self,
request: ModelRequest,
handler: Callable[[ModelRequest], AIMessage],
) -> AIMessage:
modification_order.append([t.name for t in request.tools])
# Should not see tool_c here
assert all(t.name != "tool_c" for t in request.tools)
# Remove tool_b
filtered_tools = [t for t in request.tools if t.name != "tool_b"]
handler: Callable[[ModelRequest], ModelResponse],
) -> ModelCallResult:
tools: list[str] = []
filtered_tools: list[BaseTool | dict[str, Any]] = []
for t in request.tools:
assert isinstance(t, BaseTool)
# Should not see tool_c here
assert t.name != "tool_c"
tools.append(t.name)
# Remove tool_b
if t.name != "tool_b":
filtered_tools.append(t)
modification_order.append(tools)
request = request.override(tools=filtered_tools)
return handler(request)
@@ -317,6 +353,7 @@ def test_middleware_with_additional_tools() -> None:
tool_messages = [m for m in messages if isinstance(m, ToolMessage)]
assert len(tool_messages) == 1
assert tool_messages[0].name == "middleware_tool"
assert isinstance(tool_messages[0].content, str)
assert "middleware" in tool_messages[0].content.lower()