mirror of
https://github.com/hwchase17/langchain.git
synced 2026-04-23 20:23:59 +00:00
chore(langchain): fix types in test_tools (#34592)
Co-authored-by: Mason Daugherty <mason@langchain.dev> Co-authored-by: Mason Daugherty <github@mdrxy.com>
This commit is contained in:
committed by
GitHub
parent
4e40c2766a
commit
61fd703e5f
@@ -1,14 +1,22 @@
|
||||
"""Test Middleware handling of tools in agents."""
|
||||
|
||||
from collections.abc import Callable
|
||||
from typing import Any
|
||||
|
||||
import pytest
|
||||
from langchain_core.messages import AIMessage, HumanMessage, ToolMessage
|
||||
from langchain_core.messages import HumanMessage, ToolMessage
|
||||
from langchain_core.tools import tool
|
||||
from langchain_core.tools.base import BaseTool
|
||||
from langgraph.prebuilt.tool_node import ToolNode
|
||||
|
||||
from langchain.agents.factory import create_agent
|
||||
from langchain.agents.middleware.types import AgentMiddleware, AgentState, ModelRequest
|
||||
from langchain.agents.middleware.types import (
|
||||
AgentMiddleware,
|
||||
AgentState,
|
||||
ModelCallResult,
|
||||
ModelRequest,
|
||||
ModelResponse,
|
||||
)
|
||||
from tests.unit_tests.agents.model import FakeToolCallingModel
|
||||
|
||||
|
||||
@@ -30,8 +38,8 @@ def test_model_request_tools_are_base_tools() -> None:
|
||||
def wrap_model_call(
|
||||
self,
|
||||
request: ModelRequest,
|
||||
handler: Callable[[ModelRequest], AIMessage],
|
||||
) -> AIMessage:
|
||||
handler: Callable[[ModelRequest], ModelResponse],
|
||||
) -> ModelCallResult:
|
||||
captured_requests.append(request)
|
||||
return handler(request)
|
||||
|
||||
@@ -51,7 +59,15 @@ def test_model_request_tools_are_base_tools() -> None:
|
||||
request = captured_requests[0]
|
||||
assert isinstance(request.tools, list)
|
||||
assert len(request.tools) == 2
|
||||
assert {t.name for t in request.tools} == {"search_tool", "calculator"}
|
||||
|
||||
tools = []
|
||||
for t in request.tools:
|
||||
assert isinstance(t, BaseTool)
|
||||
tools.append(t.name)
|
||||
assert set(tools) == {
|
||||
"search_tool",
|
||||
"calculator",
|
||||
}
|
||||
|
||||
|
||||
def test_middleware_can_modify_tools() -> None:
|
||||
@@ -76,10 +92,14 @@ def test_middleware_can_modify_tools() -> None:
|
||||
def wrap_model_call(
|
||||
self,
|
||||
request: ModelRequest,
|
||||
handler: Callable[[ModelRequest], AIMessage],
|
||||
) -> AIMessage:
|
||||
handler: Callable[[ModelRequest], ModelResponse],
|
||||
) -> ModelCallResult:
|
||||
# Only allow tool_a and tool_b
|
||||
filtered_tools = [t for t in request.tools if t.name in {"tool_a", "tool_b"}]
|
||||
filtered_tools: list[BaseTool | dict[str, Any]] = []
|
||||
for t in request.tools:
|
||||
assert isinstance(t, BaseTool)
|
||||
if t.name in {"tool_a", "tool_b"}:
|
||||
filtered_tools.append(t)
|
||||
return handler(request.override(tools=filtered_tools))
|
||||
|
||||
# Model will try to call tool_a
|
||||
@@ -120,8 +140,8 @@ def test_unknown_tool_raises_error() -> None:
|
||||
def wrap_model_call(
|
||||
self,
|
||||
request: ModelRequest,
|
||||
handler: Callable[[ModelRequest], AIMessage],
|
||||
) -> AIMessage:
|
||||
handler: Callable[[ModelRequest], ModelResponse],
|
||||
) -> ModelCallResult:
|
||||
# Add an unknown tool
|
||||
return handler(request.override(tools=[*request.tools, unknown_tool]))
|
||||
|
||||
@@ -149,7 +169,7 @@ def test_middleware_can_add_and_remove_tools() -> None:
|
||||
"""Admin-only tool."""
|
||||
return f"Admin: {command}"
|
||||
|
||||
class AdminState(AgentState):
|
||||
class AdminState(AgentState[Any]):
|
||||
is_admin: bool
|
||||
|
||||
class ConditionalToolMiddleware(AgentMiddleware[AdminState]):
|
||||
@@ -158,11 +178,15 @@ def test_middleware_can_add_and_remove_tools() -> None:
|
||||
def wrap_model_call(
|
||||
self,
|
||||
request: ModelRequest,
|
||||
handler: Callable[[ModelRequest], AIMessage],
|
||||
) -> AIMessage:
|
||||
handler: Callable[[ModelRequest], ModelResponse],
|
||||
) -> ModelCallResult:
|
||||
# Remove admin_tool if not admin
|
||||
if not request.state.get("is_admin", False):
|
||||
filtered_tools = [t for t in request.tools if t.name != "admin_tool"]
|
||||
filtered_tools: list[BaseTool | dict[str, Any]] = []
|
||||
for t in request.tools:
|
||||
assert isinstance(t, BaseTool)
|
||||
if t.name != "admin_tool":
|
||||
filtered_tools.append(t)
|
||||
request = request.override(tools=filtered_tools)
|
||||
return handler(request)
|
||||
|
||||
@@ -197,8 +221,8 @@ def test_empty_tools_list_is_valid() -> None:
|
||||
def wrap_model_call(
|
||||
self,
|
||||
request: ModelRequest,
|
||||
handler: Callable[[ModelRequest], AIMessage],
|
||||
) -> AIMessage:
|
||||
handler: Callable[[ModelRequest], ModelResponse],
|
||||
) -> ModelCallResult:
|
||||
# Remove all tools
|
||||
request = request.override(tools=[])
|
||||
return handler(request)
|
||||
@@ -240,11 +264,17 @@ def test_tools_preserved_across_multiple_middleware() -> None:
|
||||
def wrap_model_call(
|
||||
self,
|
||||
request: ModelRequest,
|
||||
handler: Callable[[ModelRequest], AIMessage],
|
||||
) -> AIMessage:
|
||||
modification_order.append([t.name for t in request.tools])
|
||||
# Remove tool_c
|
||||
filtered_tools = [t for t in request.tools if t.name != "tool_c"]
|
||||
handler: Callable[[ModelRequest], ModelResponse],
|
||||
) -> ModelCallResult:
|
||||
tools: list[str] = []
|
||||
filtered_tools: list[BaseTool | dict[str, Any]] = []
|
||||
for t in request.tools:
|
||||
assert isinstance(t, BaseTool)
|
||||
tools.append(t.name)
|
||||
# Remove tool_c
|
||||
if t.name != "tool_c":
|
||||
filtered_tools.append(t)
|
||||
modification_order.append(tools)
|
||||
request = request.override(tools=filtered_tools)
|
||||
return handler(request)
|
||||
|
||||
@@ -252,13 +282,19 @@ def test_tools_preserved_across_multiple_middleware() -> None:
|
||||
def wrap_model_call(
|
||||
self,
|
||||
request: ModelRequest,
|
||||
handler: Callable[[ModelRequest], AIMessage],
|
||||
) -> AIMessage:
|
||||
modification_order.append([t.name for t in request.tools])
|
||||
# Should not see tool_c here
|
||||
assert all(t.name != "tool_c" for t in request.tools)
|
||||
# Remove tool_b
|
||||
filtered_tools = [t for t in request.tools if t.name != "tool_b"]
|
||||
handler: Callable[[ModelRequest], ModelResponse],
|
||||
) -> ModelCallResult:
|
||||
tools: list[str] = []
|
||||
filtered_tools: list[BaseTool | dict[str, Any]] = []
|
||||
for t in request.tools:
|
||||
assert isinstance(t, BaseTool)
|
||||
# Should not see tool_c here
|
||||
assert t.name != "tool_c"
|
||||
tools.append(t.name)
|
||||
# Remove tool_b
|
||||
if t.name != "tool_b":
|
||||
filtered_tools.append(t)
|
||||
modification_order.append(tools)
|
||||
request = request.override(tools=filtered_tools)
|
||||
return handler(request)
|
||||
|
||||
@@ -317,6 +353,7 @@ def test_middleware_with_additional_tools() -> None:
|
||||
tool_messages = [m for m in messages if isinstance(m, ToolMessage)]
|
||||
assert len(tool_messages) == 1
|
||||
assert tool_messages[0].name == "middleware_tool"
|
||||
assert isinstance(tool_messages[0].content, str)
|
||||
assert "middleware" in tool_messages[0].content.lower()
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user