mirror of
https://github.com/hwchase17/langchain.git
synced 2026-03-18 02:53:16 +00:00
chore(langchain): fix types in test_diagram and test_sync_async_wrappers (#34591)
This commit is contained in:
committed by
GitHub
parent
9495eb348d
commit
7ca0efde04
@@ -1,87 +1,89 @@
|
||||
from collections.abc import Callable
|
||||
from typing import Any
|
||||
|
||||
from langchain_core.messages import AIMessage
|
||||
from langgraph.runtime import Runtime
|
||||
from syrupy.assertion import SnapshotAssertion
|
||||
|
||||
from langchain.agents import AgentState
|
||||
from langchain.agents.factory import create_agent
|
||||
from langchain.agents.middleware.types import AgentMiddleware, ModelRequest
|
||||
from langchain.agents.middleware.types import AgentMiddleware, ModelRequest, ModelResponse
|
||||
from tests.unit_tests.agents.model import FakeToolCallingModel
|
||||
|
||||
|
||||
def test_create_agent_diagram(
|
||||
snapshot: SnapshotAssertion,
|
||||
):
|
||||
) -> None:
|
||||
class NoopOne(AgentMiddleware):
|
||||
def before_model(self, state, runtime):
|
||||
def before_model(self, state: AgentState[Any], runtime: Runtime[None]) -> None:
|
||||
pass
|
||||
|
||||
class NoopTwo(AgentMiddleware):
|
||||
def before_model(self, state, runtime):
|
||||
def before_model(self, state: AgentState[Any], runtime: Runtime[None]) -> None:
|
||||
pass
|
||||
|
||||
class NoopThree(AgentMiddleware):
|
||||
def before_model(self, state, runtime):
|
||||
def before_model(self, state: AgentState[Any], runtime: Runtime[None]) -> None:
|
||||
pass
|
||||
|
||||
class NoopFour(AgentMiddleware):
|
||||
def after_model(self, state, runtime):
|
||||
def after_model(self, state: AgentState[Any], runtime: Runtime[None]) -> None:
|
||||
pass
|
||||
|
||||
class NoopFive(AgentMiddleware):
|
||||
def after_model(self, state, runtime):
|
||||
def after_model(self, state: AgentState[Any], runtime: Runtime[None]) -> None:
|
||||
pass
|
||||
|
||||
class NoopSix(AgentMiddleware):
|
||||
def after_model(self, state, runtime):
|
||||
def after_model(self, state: AgentState[Any], runtime: Runtime[None]) -> None:
|
||||
pass
|
||||
|
||||
class NoopSeven(AgentMiddleware):
|
||||
def before_model(self, state, runtime):
|
||||
def before_model(self, state: AgentState[Any], runtime: Runtime[None]) -> None:
|
||||
pass
|
||||
|
||||
def after_model(self, state, runtime):
|
||||
def after_model(self, state: AgentState[Any], runtime: Runtime[None]) -> None:
|
||||
pass
|
||||
|
||||
class NoopEight(AgentMiddleware):
|
||||
def before_model(self, state, runtime):
|
||||
def before_model(self, state: AgentState[Any], runtime: Runtime[None]) -> None:
|
||||
pass
|
||||
|
||||
def after_model(self, state, runtime):
|
||||
def after_model(self, state: AgentState[Any], runtime: Runtime[None]) -> None:
|
||||
pass
|
||||
|
||||
class NoopNine(AgentMiddleware):
|
||||
def before_model(self, state, runtime):
|
||||
def before_model(self, state: AgentState[Any], runtime: Runtime[None]) -> None:
|
||||
pass
|
||||
|
||||
def after_model(self, state, runtime):
|
||||
def after_model(self, state: AgentState[Any], runtime: Runtime[None]) -> None:
|
||||
pass
|
||||
|
||||
class NoopTen(AgentMiddleware):
|
||||
def before_model(self, state, runtime):
|
||||
def before_model(self, state: AgentState[Any], runtime: Runtime[None]) -> None:
|
||||
pass
|
||||
|
||||
def wrap_model_call(
|
||||
self,
|
||||
request: ModelRequest,
|
||||
handler: Callable[[ModelRequest], AIMessage],
|
||||
) -> AIMessage:
|
||||
handler: Callable[[ModelRequest], ModelResponse],
|
||||
) -> ModelResponse:
|
||||
return handler(request)
|
||||
|
||||
def after_model(self, state, runtime):
|
||||
def after_model(self, state: AgentState[Any], runtime: Runtime[None]) -> None:
|
||||
pass
|
||||
|
||||
class NoopEleven(AgentMiddleware):
|
||||
def before_model(self, state, runtime):
|
||||
def before_model(self, state: AgentState[Any], runtime: Runtime[None]) -> None:
|
||||
pass
|
||||
|
||||
def wrap_model_call(
|
||||
self,
|
||||
request: ModelRequest,
|
||||
handler: Callable[[ModelRequest], AIMessage],
|
||||
) -> AIMessage:
|
||||
handler: Callable[[ModelRequest], ModelResponse],
|
||||
) -> ModelResponse:
|
||||
return handler(request)
|
||||
|
||||
def after_model(self, state, runtime):
|
||||
def after_model(self, state: AgentState[Any], runtime: Runtime[None]) -> None:
|
||||
pass
|
||||
|
||||
agent_zero = create_agent(
|
||||
|
||||
@@ -6,10 +6,14 @@ These tests verify the desired behavior:
|
||||
3. If middleware defines only async -> use on async path, raise NotImplementedError on sync path
|
||||
"""
|
||||
|
||||
from collections.abc import Awaitable, Callable
|
||||
from typing import Any
|
||||
|
||||
import pytest
|
||||
from langchain_core.messages import HumanMessage, ToolCall, ToolMessage
|
||||
from langchain_core.tools import tool
|
||||
from langgraph.checkpoint.memory import InMemorySaver
|
||||
from langgraph.types import Command
|
||||
|
||||
from langchain.agents.factory import create_agent
|
||||
from langchain.agents.middleware.types import AgentMiddleware, ToolCallRequest, wrap_tool_call
|
||||
@@ -36,7 +40,11 @@ class TestSyncAsyncMiddlewareComposition:
|
||||
call_log = []
|
||||
|
||||
class SyncOnlyMiddleware(AgentMiddleware):
|
||||
def wrap_tool_call(self, request, handler):
|
||||
def wrap_tool_call(
|
||||
self,
|
||||
request: ToolCallRequest,
|
||||
handler: Callable[[ToolCallRequest], ToolMessage | Command[Any]],
|
||||
) -> ToolMessage | Command[Any]:
|
||||
call_log.append("sync_called")
|
||||
return handler(request)
|
||||
|
||||
@@ -68,7 +76,11 @@ class TestSyncAsyncMiddlewareComposition:
|
||||
"""Middleware with only sync wrap_tool_call raises NotImplementedError on async path."""
|
||||
|
||||
class SyncOnlyMiddleware(AgentMiddleware):
|
||||
def wrap_tool_call(self, request, handler):
|
||||
def wrap_tool_call(
|
||||
self,
|
||||
request: ToolCallRequest,
|
||||
handler: Callable[[ToolCallRequest], ToolMessage | Command[Any]],
|
||||
) -> ToolMessage | Command[Any]:
|
||||
return handler(request)
|
||||
|
||||
model = FakeToolCallingModel(
|
||||
@@ -97,7 +109,11 @@ class TestSyncAsyncMiddlewareComposition:
|
||||
call_log = []
|
||||
|
||||
class AsyncOnlyMiddleware(AgentMiddleware):
|
||||
async def awrap_tool_call(self, request, handler):
|
||||
async def awrap_tool_call(
|
||||
self,
|
||||
request: ToolCallRequest,
|
||||
handler: Callable[[ToolCallRequest], Awaitable[ToolMessage | Command[Any]]],
|
||||
) -> ToolMessage | Command[Any]:
|
||||
call_log.append("async_called")
|
||||
return await handler(request)
|
||||
|
||||
@@ -129,7 +145,11 @@ class TestSyncAsyncMiddlewareComposition:
|
||||
"""Middleware with only async awrap_tool_call raises NotImplementedError on sync path."""
|
||||
|
||||
class AsyncOnlyMiddleware(AgentMiddleware):
|
||||
async def awrap_tool_call(self, request, handler):
|
||||
async def awrap_tool_call(
|
||||
self,
|
||||
request: ToolCallRequest,
|
||||
handler: Callable[[ToolCallRequest], Awaitable[ToolMessage | Command[Any]]],
|
||||
) -> ToolMessage | Command[Any]:
|
||||
return await handler(request)
|
||||
|
||||
model = FakeToolCallingModel(
|
||||
@@ -157,11 +177,19 @@ class TestSyncAsyncMiddlewareComposition:
|
||||
call_log = []
|
||||
|
||||
class BothSyncAsyncMiddleware(AgentMiddleware):
|
||||
def wrap_tool_call(self, request, handler):
|
||||
def wrap_tool_call(
|
||||
self,
|
||||
request: ToolCallRequest,
|
||||
handler: Callable[[ToolCallRequest], ToolMessage | Command[Any]],
|
||||
) -> ToolMessage | Command[Any]:
|
||||
call_log.append("sync_called")
|
||||
return handler(request)
|
||||
|
||||
async def awrap_tool_call(self, request, handler):
|
||||
async def awrap_tool_call(
|
||||
self,
|
||||
request: ToolCallRequest,
|
||||
handler: Callable[[ToolCallRequest], Awaitable[ToolMessage | Command[Any]]],
|
||||
) -> ToolMessage | Command[Any]:
|
||||
call_log.append("async_called")
|
||||
return await handler(request)
|
||||
|
||||
@@ -195,11 +223,19 @@ class TestSyncAsyncMiddlewareComposition:
|
||||
call_log = []
|
||||
|
||||
class BothSyncAsyncMiddleware(AgentMiddleware):
|
||||
def wrap_tool_call(self, request, handler):
|
||||
def wrap_tool_call(
|
||||
self,
|
||||
request: ToolCallRequest,
|
||||
handler: Callable[[ToolCallRequest], ToolMessage | Command[Any]],
|
||||
) -> ToolMessage | Command[Any]:
|
||||
call_log.append("sync_called")
|
||||
return handler(request)
|
||||
|
||||
async def awrap_tool_call(self, request, handler):
|
||||
async def awrap_tool_call(
|
||||
self,
|
||||
request: ToolCallRequest,
|
||||
handler: Callable[[ToolCallRequest], Awaitable[ToolMessage | Command[Any]]],
|
||||
) -> ToolMessage | Command[Any]:
|
||||
call_log.append("async_called")
|
||||
return await handler(request)
|
||||
|
||||
@@ -234,13 +270,21 @@ class TestSyncAsyncMiddlewareComposition:
|
||||
class SyncOnlyMiddleware(AgentMiddleware):
|
||||
name = "SyncOnly"
|
||||
|
||||
def wrap_tool_call(self, request, handler):
|
||||
def wrap_tool_call(
|
||||
self,
|
||||
request: ToolCallRequest,
|
||||
handler: Callable[[ToolCallRequest], ToolMessage | Command[Any]],
|
||||
) -> ToolMessage | Command[Any]:
|
||||
return handler(request)
|
||||
|
||||
class AsyncOnlyMiddleware(AgentMiddleware):
|
||||
name = "AsyncOnly"
|
||||
|
||||
async def awrap_tool_call(self, request, handler):
|
||||
async def awrap_tool_call(
|
||||
self,
|
||||
request: ToolCallRequest,
|
||||
handler: Callable[[ToolCallRequest], Awaitable[ToolMessage | Command[Any]]],
|
||||
) -> ToolMessage | Command[Any]:
|
||||
return await handler(request)
|
||||
|
||||
model = FakeToolCallingModel(
|
||||
@@ -273,13 +317,21 @@ class TestSyncAsyncMiddlewareComposition:
|
||||
class SyncOnlyMiddleware(AgentMiddleware):
|
||||
name = "SyncOnly"
|
||||
|
||||
def wrap_tool_call(self, request, handler):
|
||||
def wrap_tool_call(
|
||||
self,
|
||||
request: ToolCallRequest,
|
||||
handler: Callable[[ToolCallRequest], ToolMessage | Command[Any]],
|
||||
) -> ToolMessage | Command[Any]:
|
||||
return handler(request)
|
||||
|
||||
class AsyncOnlyMiddleware(AgentMiddleware):
|
||||
name = "AsyncOnly"
|
||||
|
||||
async def awrap_tool_call(self, request, handler):
|
||||
async def awrap_tool_call(
|
||||
self,
|
||||
request: ToolCallRequest,
|
||||
handler: Callable[[ToolCallRequest], Awaitable[ToolMessage | Command[Any]]],
|
||||
) -> ToolMessage | Command[Any]:
|
||||
return await handler(request)
|
||||
|
||||
model = FakeToolCallingModel(
|
||||
@@ -311,7 +363,10 @@ class TestSyncAsyncMiddlewareComposition:
|
||||
call_log = []
|
||||
|
||||
@wrap_tool_call
|
||||
def my_wrapper(request: ToolCallRequest, handler):
|
||||
def my_wrapper(
|
||||
request: ToolCallRequest,
|
||||
handler: Callable[[ToolCallRequest], ToolMessage | Command[Any]],
|
||||
) -> ToolMessage | Command[Any]:
|
||||
call_log.append("decorator_sync")
|
||||
return handler(request)
|
||||
|
||||
@@ -343,7 +398,10 @@ class TestSyncAsyncMiddlewareComposition:
|
||||
call_log = []
|
||||
|
||||
@wrap_tool_call
|
||||
def my_wrapper(request: ToolCallRequest, handler):
|
||||
def my_wrapper(
|
||||
request: ToolCallRequest,
|
||||
handler: Callable[[ToolCallRequest], ToolMessage | Command[Any]],
|
||||
) -> ToolMessage | Command[Any]:
|
||||
call_log.append("decorator_sync")
|
||||
return handler(request)
|
||||
|
||||
@@ -373,7 +431,10 @@ class TestSyncAsyncMiddlewareComposition:
|
||||
call_log = []
|
||||
|
||||
@wrap_tool_call
|
||||
async def my_async_wrapper(request: ToolCallRequest, handler):
|
||||
async def my_async_wrapper(
|
||||
request: ToolCallRequest,
|
||||
handler: Callable[[ToolCallRequest], Awaitable[ToolMessage | Command[Any]]],
|
||||
) -> ToolMessage | Command[Any]:
|
||||
call_log.append("decorator_async")
|
||||
return await handler(request)
|
||||
|
||||
@@ -402,7 +463,10 @@ class TestSyncAsyncMiddlewareComposition:
|
||||
"""Decorator-created async-only middleware raises on sync path."""
|
||||
|
||||
@wrap_tool_call
|
||||
async def my_async_wrapper(request: ToolCallRequest, handler):
|
||||
async def my_async_wrapper(
|
||||
request: ToolCallRequest,
|
||||
handler: Callable[[ToolCallRequest], Awaitable[ToolMessage | Command[Any]]],
|
||||
) -> ToolMessage | Command[Any]:
|
||||
return await handler(request)
|
||||
|
||||
model = FakeToolCallingModel(
|
||||
|
||||
Reference in New Issue
Block a user