chore(langchain): fix types in test_diagram and test_sync_async_wrappers (#34591)

This commit is contained in:
Christophe Bornet
2026-01-05 15:05:24 +01:00
committed by GitHub
parent 9495eb348d
commit 7ca0efde04
2 changed files with 105 additions and 39 deletions

View File

@@ -1,87 +1,89 @@
from collections.abc import Callable
from typing import Any
from langchain_core.messages import AIMessage
from langgraph.runtime import Runtime
from syrupy.assertion import SnapshotAssertion
from langchain.agents import AgentState
from langchain.agents.factory import create_agent
from langchain.agents.middleware.types import AgentMiddleware, ModelRequest
from langchain.agents.middleware.types import AgentMiddleware, ModelRequest, ModelResponse
from tests.unit_tests.agents.model import FakeToolCallingModel
def test_create_agent_diagram(
snapshot: SnapshotAssertion,
):
) -> None:
class NoopOne(AgentMiddleware):
def before_model(self, state, runtime):
def before_model(self, state: AgentState[Any], runtime: Runtime[None]) -> None:
pass
class NoopTwo(AgentMiddleware):
def before_model(self, state, runtime):
def before_model(self, state: AgentState[Any], runtime: Runtime[None]) -> None:
pass
class NoopThree(AgentMiddleware):
def before_model(self, state, runtime):
def before_model(self, state: AgentState[Any], runtime: Runtime[None]) -> None:
pass
class NoopFour(AgentMiddleware):
def after_model(self, state, runtime):
def after_model(self, state: AgentState[Any], runtime: Runtime[None]) -> None:
pass
class NoopFive(AgentMiddleware):
def after_model(self, state, runtime):
def after_model(self, state: AgentState[Any], runtime: Runtime[None]) -> None:
pass
class NoopSix(AgentMiddleware):
def after_model(self, state, runtime):
def after_model(self, state: AgentState[Any], runtime: Runtime[None]) -> None:
pass
class NoopSeven(AgentMiddleware):
def before_model(self, state, runtime):
def before_model(self, state: AgentState[Any], runtime: Runtime[None]) -> None:
pass
def after_model(self, state, runtime):
def after_model(self, state: AgentState[Any], runtime: Runtime[None]) -> None:
pass
class NoopEight(AgentMiddleware):
def before_model(self, state, runtime):
def before_model(self, state: AgentState[Any], runtime: Runtime[None]) -> None:
pass
def after_model(self, state, runtime):
def after_model(self, state: AgentState[Any], runtime: Runtime[None]) -> None:
pass
class NoopNine(AgentMiddleware):
def before_model(self, state, runtime):
def before_model(self, state: AgentState[Any], runtime: Runtime[None]) -> None:
pass
def after_model(self, state, runtime):
def after_model(self, state: AgentState[Any], runtime: Runtime[None]) -> None:
pass
class NoopTen(AgentMiddleware):
def before_model(self, state, runtime):
def before_model(self, state: AgentState[Any], runtime: Runtime[None]) -> None:
pass
def wrap_model_call(
self,
request: ModelRequest,
handler: Callable[[ModelRequest], AIMessage],
) -> AIMessage:
handler: Callable[[ModelRequest], ModelResponse],
) -> ModelResponse:
return handler(request)
def after_model(self, state, runtime):
def after_model(self, state: AgentState[Any], runtime: Runtime[None]) -> None:
pass
class NoopEleven(AgentMiddleware):
def before_model(self, state, runtime):
def before_model(self, state: AgentState[Any], runtime: Runtime[None]) -> None:
pass
def wrap_model_call(
self,
request: ModelRequest,
handler: Callable[[ModelRequest], AIMessage],
) -> AIMessage:
handler: Callable[[ModelRequest], ModelResponse],
) -> ModelResponse:
return handler(request)
def after_model(self, state, runtime):
def after_model(self, state: AgentState[Any], runtime: Runtime[None]) -> None:
pass
agent_zero = create_agent(

View File

@@ -6,10 +6,14 @@ These tests verify the desired behavior:
3. If middleware defines only async -> use on async path, raise NotImplementedError on sync path
"""
from collections.abc import Awaitable, Callable
from typing import Any
import pytest
from langchain_core.messages import HumanMessage, ToolCall, ToolMessage
from langchain_core.tools import tool
from langgraph.checkpoint.memory import InMemorySaver
from langgraph.types import Command
from langchain.agents.factory import create_agent
from langchain.agents.middleware.types import AgentMiddleware, ToolCallRequest, wrap_tool_call
@@ -36,7 +40,11 @@ class TestSyncAsyncMiddlewareComposition:
call_log = []
class SyncOnlyMiddleware(AgentMiddleware):
def wrap_tool_call(self, request, handler):
def wrap_tool_call(
self,
request: ToolCallRequest,
handler: Callable[[ToolCallRequest], ToolMessage | Command[Any]],
) -> ToolMessage | Command[Any]:
call_log.append("sync_called")
return handler(request)
@@ -68,7 +76,11 @@ class TestSyncAsyncMiddlewareComposition:
"""Middleware with only sync wrap_tool_call raises NotImplementedError on async path."""
class SyncOnlyMiddleware(AgentMiddleware):
def wrap_tool_call(self, request, handler):
def wrap_tool_call(
self,
request: ToolCallRequest,
handler: Callable[[ToolCallRequest], ToolMessage | Command[Any]],
) -> ToolMessage | Command[Any]:
return handler(request)
model = FakeToolCallingModel(
@@ -97,7 +109,11 @@ class TestSyncAsyncMiddlewareComposition:
call_log = []
class AsyncOnlyMiddleware(AgentMiddleware):
async def awrap_tool_call(self, request, handler):
async def awrap_tool_call(
self,
request: ToolCallRequest,
handler: Callable[[ToolCallRequest], Awaitable[ToolMessage | Command[Any]]],
) -> ToolMessage | Command[Any]:
call_log.append("async_called")
return await handler(request)
@@ -129,7 +145,11 @@ class TestSyncAsyncMiddlewareComposition:
"""Middleware with only async awrap_tool_call raises NotImplementedError on sync path."""
class AsyncOnlyMiddleware(AgentMiddleware):
async def awrap_tool_call(self, request, handler):
async def awrap_tool_call(
self,
request: ToolCallRequest,
handler: Callable[[ToolCallRequest], Awaitable[ToolMessage | Command[Any]]],
) -> ToolMessage | Command[Any]:
return await handler(request)
model = FakeToolCallingModel(
@@ -157,11 +177,19 @@ class TestSyncAsyncMiddlewareComposition:
call_log = []
class BothSyncAsyncMiddleware(AgentMiddleware):
def wrap_tool_call(self, request, handler):
def wrap_tool_call(
self,
request: ToolCallRequest,
handler: Callable[[ToolCallRequest], ToolMessage | Command[Any]],
) -> ToolMessage | Command[Any]:
call_log.append("sync_called")
return handler(request)
async def awrap_tool_call(self, request, handler):
async def awrap_tool_call(
self,
request: ToolCallRequest,
handler: Callable[[ToolCallRequest], Awaitable[ToolMessage | Command[Any]]],
) -> ToolMessage | Command[Any]:
call_log.append("async_called")
return await handler(request)
@@ -195,11 +223,19 @@ class TestSyncAsyncMiddlewareComposition:
call_log = []
class BothSyncAsyncMiddleware(AgentMiddleware):
def wrap_tool_call(self, request, handler):
def wrap_tool_call(
self,
request: ToolCallRequest,
handler: Callable[[ToolCallRequest], ToolMessage | Command[Any]],
) -> ToolMessage | Command[Any]:
call_log.append("sync_called")
return handler(request)
async def awrap_tool_call(self, request, handler):
async def awrap_tool_call(
self,
request: ToolCallRequest,
handler: Callable[[ToolCallRequest], Awaitable[ToolMessage | Command[Any]]],
) -> ToolMessage | Command[Any]:
call_log.append("async_called")
return await handler(request)
@@ -234,13 +270,21 @@ class TestSyncAsyncMiddlewareComposition:
class SyncOnlyMiddleware(AgentMiddleware):
name = "SyncOnly"
def wrap_tool_call(self, request, handler):
def wrap_tool_call(
self,
request: ToolCallRequest,
handler: Callable[[ToolCallRequest], ToolMessage | Command[Any]],
) -> ToolMessage | Command[Any]:
return handler(request)
class AsyncOnlyMiddleware(AgentMiddleware):
name = "AsyncOnly"
async def awrap_tool_call(self, request, handler):
async def awrap_tool_call(
self,
request: ToolCallRequest,
handler: Callable[[ToolCallRequest], Awaitable[ToolMessage | Command[Any]]],
) -> ToolMessage | Command[Any]:
return await handler(request)
model = FakeToolCallingModel(
@@ -273,13 +317,21 @@ class TestSyncAsyncMiddlewareComposition:
class SyncOnlyMiddleware(AgentMiddleware):
name = "SyncOnly"
def wrap_tool_call(self, request, handler):
def wrap_tool_call(
self,
request: ToolCallRequest,
handler: Callable[[ToolCallRequest], ToolMessage | Command[Any]],
) -> ToolMessage | Command[Any]:
return handler(request)
class AsyncOnlyMiddleware(AgentMiddleware):
name = "AsyncOnly"
async def awrap_tool_call(self, request, handler):
async def awrap_tool_call(
self,
request: ToolCallRequest,
handler: Callable[[ToolCallRequest], Awaitable[ToolMessage | Command[Any]]],
) -> ToolMessage | Command[Any]:
return await handler(request)
model = FakeToolCallingModel(
@@ -311,7 +363,10 @@ class TestSyncAsyncMiddlewareComposition:
call_log = []
@wrap_tool_call
def my_wrapper(request: ToolCallRequest, handler):
def my_wrapper(
request: ToolCallRequest,
handler: Callable[[ToolCallRequest], ToolMessage | Command[Any]],
) -> ToolMessage | Command[Any]:
call_log.append("decorator_sync")
return handler(request)
@@ -343,7 +398,10 @@ class TestSyncAsyncMiddlewareComposition:
call_log = []
@wrap_tool_call
def my_wrapper(request: ToolCallRequest, handler):
def my_wrapper(
request: ToolCallRequest,
handler: Callable[[ToolCallRequest], ToolMessage | Command[Any]],
) -> ToolMessage | Command[Any]:
call_log.append("decorator_sync")
return handler(request)
@@ -373,7 +431,10 @@ class TestSyncAsyncMiddlewareComposition:
call_log = []
@wrap_tool_call
async def my_async_wrapper(request: ToolCallRequest, handler):
async def my_async_wrapper(
request: ToolCallRequest,
handler: Callable[[ToolCallRequest], Awaitable[ToolMessage | Command[Any]]],
) -> ToolMessage | Command[Any]:
call_log.append("decorator_async")
return await handler(request)
@@ -402,7 +463,10 @@ class TestSyncAsyncMiddlewareComposition:
"""Decorator-created async-only middleware raises on sync path."""
@wrap_tool_call
async def my_async_wrapper(request: ToolCallRequest, handler):
async def my_async_wrapper(
request: ToolCallRequest,
handler: Callable[[ToolCallRequest], Awaitable[ToolMessage | Command[Any]]],
) -> ToolMessage | Command[Any]:
return await handler(request)
model = FakeToolCallingModel(