diff --git a/libs/langchain_v1/tests/unit_tests/agents/middleware/core/test_diagram.py b/libs/langchain_v1/tests/unit_tests/agents/middleware/core/test_diagram.py index 6c15500699d..765805f6542 100644 --- a/libs/langchain_v1/tests/unit_tests/agents/middleware/core/test_diagram.py +++ b/libs/langchain_v1/tests/unit_tests/agents/middleware/core/test_diagram.py @@ -1,87 +1,89 @@ from collections.abc import Callable +from typing import Any -from langchain_core.messages import AIMessage +from langgraph.runtime import Runtime from syrupy.assertion import SnapshotAssertion +from langchain.agents import AgentState from langchain.agents.factory import create_agent -from langchain.agents.middleware.types import AgentMiddleware, ModelRequest +from langchain.agents.middleware.types import AgentMiddleware, ModelRequest, ModelResponse from tests.unit_tests.agents.model import FakeToolCallingModel def test_create_agent_diagram( snapshot: SnapshotAssertion, -): +) -> None: class NoopOne(AgentMiddleware): - def before_model(self, state, runtime): + def before_model(self, state: AgentState[Any], runtime: Runtime[None]) -> None: pass class NoopTwo(AgentMiddleware): - def before_model(self, state, runtime): + def before_model(self, state: AgentState[Any], runtime: Runtime[None]) -> None: pass class NoopThree(AgentMiddleware): - def before_model(self, state, runtime): + def before_model(self, state: AgentState[Any], runtime: Runtime[None]) -> None: pass class NoopFour(AgentMiddleware): - def after_model(self, state, runtime): + def after_model(self, state: AgentState[Any], runtime: Runtime[None]) -> None: pass class NoopFive(AgentMiddleware): - def after_model(self, state, runtime): + def after_model(self, state: AgentState[Any], runtime: Runtime[None]) -> None: pass class NoopSix(AgentMiddleware): - def after_model(self, state, runtime): + def after_model(self, state: AgentState[Any], runtime: Runtime[None]) -> None: pass class NoopSeven(AgentMiddleware): - def before_model(self, state, runtime): + def before_model(self, state: AgentState[Any], runtime: Runtime[None]) -> None: pass - def after_model(self, state, runtime): + def after_model(self, state: AgentState[Any], runtime: Runtime[None]) -> None: pass class NoopEight(AgentMiddleware): - def before_model(self, state, runtime): + def before_model(self, state: AgentState[Any], runtime: Runtime[None]) -> None: pass - def after_model(self, state, runtime): + def after_model(self, state: AgentState[Any], runtime: Runtime[None]) -> None: pass class NoopNine(AgentMiddleware): - def before_model(self, state, runtime): + def before_model(self, state: AgentState[Any], runtime: Runtime[None]) -> None: pass - def after_model(self, state, runtime): + def after_model(self, state: AgentState[Any], runtime: Runtime[None]) -> None: pass class NoopTen(AgentMiddleware): - def before_model(self, state, runtime): + def before_model(self, state: AgentState[Any], runtime: Runtime[None]) -> None: pass def wrap_model_call( self, request: ModelRequest, - handler: Callable[[ModelRequest], AIMessage], - ) -> AIMessage: + handler: Callable[[ModelRequest], ModelResponse], + ) -> ModelResponse: return handler(request) - def after_model(self, state, runtime): + def after_model(self, state: AgentState[Any], runtime: Runtime[None]) -> None: pass class NoopEleven(AgentMiddleware): - def before_model(self, state, runtime): + def before_model(self, state: AgentState[Any], runtime: Runtime[None]) -> None: pass def wrap_model_call( self, request: ModelRequest, - handler: Callable[[ModelRequest], AIMessage], - ) -> AIMessage: + handler: Callable[[ModelRequest], ModelResponse], + ) -> ModelResponse: return handler(request) - def after_model(self, state, runtime): + def after_model(self, state: AgentState[Any], runtime: Runtime[None]) -> None: pass agent_zero = create_agent( diff --git a/libs/langchain_v1/tests/unit_tests/agents/middleware/core/test_sync_async_wrappers.py b/libs/langchain_v1/tests/unit_tests/agents/middleware/core/test_sync_async_wrappers.py index d68e88976b0..140cf3a4b89 100644 --- a/libs/langchain_v1/tests/unit_tests/agents/middleware/core/test_sync_async_wrappers.py +++ b/libs/langchain_v1/tests/unit_tests/agents/middleware/core/test_sync_async_wrappers.py @@ -6,10 +6,14 @@ These tests verify the desired behavior: 3. If middleware defines only async -> use on async path, raise NotImplementedError on sync path """ +from collections.abc import Awaitable, Callable +from typing import Any + import pytest from langchain_core.messages import HumanMessage, ToolCall, ToolMessage from langchain_core.tools import tool from langgraph.checkpoint.memory import InMemorySaver +from langgraph.types import Command from langchain.agents.factory import create_agent from langchain.agents.middleware.types import AgentMiddleware, ToolCallRequest, wrap_tool_call @@ -36,7 +40,11 @@ class TestSyncAsyncMiddlewareComposition: call_log = [] class SyncOnlyMiddleware(AgentMiddleware): - def wrap_tool_call(self, request, handler): + def wrap_tool_call( + self, + request: ToolCallRequest, + handler: Callable[[ToolCallRequest], ToolMessage | Command[Any]], + ) -> ToolMessage | Command[Any]: call_log.append("sync_called") return handler(request) @@ -68,7 +76,11 @@ class TestSyncAsyncMiddlewareComposition: """Middleware with only sync wrap_tool_call raises NotImplementedError on async path.""" class SyncOnlyMiddleware(AgentMiddleware): - def wrap_tool_call(self, request, handler): + def wrap_tool_call( + self, + request: ToolCallRequest, + handler: Callable[[ToolCallRequest], ToolMessage | Command[Any]], + ) -> ToolMessage | Command[Any]: return handler(request) model = FakeToolCallingModel( @@ -97,7 +109,11 @@ class TestSyncAsyncMiddlewareComposition: call_log = [] class AsyncOnlyMiddleware(AgentMiddleware): - async def awrap_tool_call(self, request, handler): + async def awrap_tool_call( + self, + request: ToolCallRequest, + handler: Callable[[ToolCallRequest], Awaitable[ToolMessage | Command[Any]]], + ) -> ToolMessage | Command[Any]: call_log.append("async_called") return await handler(request) @@ -129,7 +145,11 @@ class TestSyncAsyncMiddlewareComposition: """Middleware with only async awrap_tool_call raises NotImplementedError on sync path.""" class AsyncOnlyMiddleware(AgentMiddleware): - async def awrap_tool_call(self, request, handler): + async def awrap_tool_call( + self, + request: ToolCallRequest, + handler: Callable[[ToolCallRequest], Awaitable[ToolMessage | Command[Any]]], + ) -> ToolMessage | Command[Any]: return await handler(request) model = FakeToolCallingModel( @@ -157,11 +177,19 @@ class TestSyncAsyncMiddlewareComposition: call_log = [] class BothSyncAsyncMiddleware(AgentMiddleware): - def wrap_tool_call(self, request, handler): + def wrap_tool_call( + self, + request: ToolCallRequest, + handler: Callable[[ToolCallRequest], ToolMessage | Command[Any]], + ) -> ToolMessage | Command[Any]: call_log.append("sync_called") return handler(request) - async def awrap_tool_call(self, request, handler): + async def awrap_tool_call( + self, + request: ToolCallRequest, + handler: Callable[[ToolCallRequest], Awaitable[ToolMessage | Command[Any]]], + ) -> ToolMessage | Command[Any]: call_log.append("async_called") return await handler(request) @@ -195,11 +223,19 @@ class TestSyncAsyncMiddlewareComposition: call_log = [] class BothSyncAsyncMiddleware(AgentMiddleware): - def wrap_tool_call(self, request, handler): + def wrap_tool_call( + self, + request: ToolCallRequest, + handler: Callable[[ToolCallRequest], ToolMessage | Command[Any]], + ) -> ToolMessage | Command[Any]: call_log.append("sync_called") return handler(request) - async def awrap_tool_call(self, request, handler): + async def awrap_tool_call( + self, + request: ToolCallRequest, + handler: Callable[[ToolCallRequest], Awaitable[ToolMessage | Command[Any]]], + ) -> ToolMessage | Command[Any]: call_log.append("async_called") return await handler(request) @@ -234,13 +270,21 @@ class TestSyncAsyncMiddlewareComposition: class SyncOnlyMiddleware(AgentMiddleware): name = "SyncOnly" - def wrap_tool_call(self, request, handler): + def wrap_tool_call( + self, + request: ToolCallRequest, + handler: Callable[[ToolCallRequest], ToolMessage | Command[Any]], + ) -> ToolMessage | Command[Any]: return handler(request) class AsyncOnlyMiddleware(AgentMiddleware): name = "AsyncOnly" - async def awrap_tool_call(self, request, handler): + async def awrap_tool_call( + self, + request: ToolCallRequest, + handler: Callable[[ToolCallRequest], Awaitable[ToolMessage | Command[Any]]], + ) -> ToolMessage | Command[Any]: return await handler(request) model = FakeToolCallingModel( @@ -273,13 +317,21 @@ class TestSyncAsyncMiddlewareComposition: class SyncOnlyMiddleware(AgentMiddleware): name = "SyncOnly" - def wrap_tool_call(self, request, handler): + def wrap_tool_call( + self, + request: ToolCallRequest, + handler: Callable[[ToolCallRequest], ToolMessage | Command[Any]], + ) -> ToolMessage | Command[Any]: return handler(request) class AsyncOnlyMiddleware(AgentMiddleware): name = "AsyncOnly" - async def awrap_tool_call(self, request, handler): + async def awrap_tool_call( + self, + request: ToolCallRequest, + handler: Callable[[ToolCallRequest], Awaitable[ToolMessage | Command[Any]]], + ) -> ToolMessage | Command[Any]: return await handler(request) model = FakeToolCallingModel( @@ -311,7 +363,10 @@ class TestSyncAsyncMiddlewareComposition: call_log = [] @wrap_tool_call - def my_wrapper(request: ToolCallRequest, handler): + def my_wrapper( + request: ToolCallRequest, + handler: Callable[[ToolCallRequest], ToolMessage | Command[Any]], + ) -> ToolMessage | Command[Any]: call_log.append("decorator_sync") return handler(request) @@ -343,7 +398,10 @@ class TestSyncAsyncMiddlewareComposition: call_log = [] @wrap_tool_call - def my_wrapper(request: ToolCallRequest, handler): + def my_wrapper( + request: ToolCallRequest, + handler: Callable[[ToolCallRequest], ToolMessage | Command[Any]], + ) -> ToolMessage | Command[Any]: call_log.append("decorator_sync") return handler(request) @@ -373,7 +431,10 @@ class TestSyncAsyncMiddlewareComposition: call_log = [] @wrap_tool_call - async def my_async_wrapper(request: ToolCallRequest, handler): + async def my_async_wrapper( + request: ToolCallRequest, + handler: Callable[[ToolCallRequest], Awaitable[ToolMessage | Command[Any]]], + ) -> ToolMessage | Command[Any]: call_log.append("decorator_async") return await handler(request) @@ -402,7 +463,10 @@ class TestSyncAsyncMiddlewareComposition: """Decorator-created async-only middleware raises on sync path.""" @wrap_tool_call - async def my_async_wrapper(request: ToolCallRequest, handler): + async def my_async_wrapper( + request: ToolCallRequest, + handler: Callable[[ToolCallRequest], Awaitable[ToolMessage | Command[Any]]], + ) -> ToolMessage | Command[Any]: return await handler(request) model = FakeToolCallingModel(