diff --git a/libs/langchain_v1/tests/unit_tests/agents/middleware/core/test_wrap_model_call.py b/libs/langchain_v1/tests/unit_tests/agents/middleware/core/test_wrap_model_call.py
index fbdcb4fd056..65fb57a90fd 100644
--- a/libs/langchain_v1/tests/unit_tests/agents/middleware/core/test_wrap_model_call.py
+++ b/libs/langchain_v1/tests/unit_tests/agents/middleware/core/test_wrap_model_call.py
@@ -7,18 +7,27 @@ This module tests the wrap_model_call functionality in three forms:
 """
 
 from collections.abc import Awaitable, Callable
+from typing import Any
 
 import pytest
+from langchain_core.callbacks import (
+    AsyncCallbackManagerForLLMRun,
+    CallbackManagerForLLMRun,
+)
 from langchain_core.language_models.chat_models import BaseChatModel
 from langchain_core.language_models.fake_chat_models import GenericFakeChatModel
-from langchain_core.messages import AIMessage, HumanMessage
+from langchain_core.messages import AIMessage, BaseMessage, HumanMessage
+from langchain_core.outputs import ChatResult
 from langchain_core.tools import tool
-from typing_extensions import TypedDict
+from langgraph.runtime import Runtime
+from typing_extensions import TypedDict, override
 
-from langchain.agents import create_agent
+from langchain.agents import AgentState, create_agent
 from langchain.agents.middleware.types import (
     AgentMiddleware,
+    ModelCallResult,
     ModelRequest,
+    ModelResponse,
     wrap_model_call,
 )
 from tests.unit_tests.agents.model import FakeToolCallingModel
@@ -31,7 +40,11 @@ class TestBasicWrapModelCall:
         """Test middleware that simply passes through without modification."""
 
         class PassthroughMiddleware(AgentMiddleware):
-            def wrap_model_call(self, request, handler):
+            def wrap_model_call(
+                self,
+                request: ModelRequest,
+                handler: Callable[[ModelRequest], ModelResponse],
+            ) -> ModelCallResult:
                 return handler(request)
 
         model = GenericFakeChatModel(messages=iter([AIMessage(content="Hello")]))
@@ -47,7 +60,11 @@ class TestBasicWrapModelCall:
         call_log = []
 
         class LoggingMiddleware(AgentMiddleware):
-            def wrap_model_call(self, request, handler):
+            def wrap_model_call(
+                self,
+                request: ModelRequest,
+                handler: Callable[[ModelRequest], ModelResponse],
+            ) -> ModelCallResult:
                 call_log.append("before")
                 result = handler(request)
                 call_log.append("after")
                 return result
@@ -65,11 +82,15 @@ class TestBasicWrapModelCall:
         """Test middleware that counts model calls."""
 
         class CountingMiddleware(AgentMiddleware):
-            def __init__(self):
+            def __init__(self) -> None:
                 super().__init__()
                 self.call_count = 0
 
-            def wrap_model_call(self, request, handler):
+            def wrap_model_call(
+                self,
+                request: ModelRequest,
+                handler: Callable[[ModelRequest], ModelResponse],
+            ) -> ModelCallResult:
                 self.call_count += 1
                 return handler(request)
 
@@ -90,7 +111,14 @@ class TestRetryLogic:
         call_count = {"value": 0}
 
         class FailOnceThenSucceed(GenericFakeChatModel):
-            def _generate(self, messages, **kwargs):
+            @override
+            def _generate(
+                self,
+                messages: list[BaseMessage],
+                stop: list[str] | None = None,
+                run_manager: CallbackManagerForLLMRun | None = None,
+                **kwargs: Any,
+            ) -> ChatResult:
                 call_count["value"] += 1
                 if call_count["value"] == 1:
                     msg = "First call fails"
@@ -98,11 +126,15 @@ class TestRetryLogic:
                 return super()._generate(messages, **kwargs)
 
         class RetryOnceMiddleware(AgentMiddleware):
-            def __init__(self):
+            def __init__(self) -> None:
                 super().__init__()
                 self.retry_count = 0
 
-            def wrap_model_call(self, request, handler):
+            def wrap_model_call(
+                self,
+                request: ModelRequest,
+                handler: Callable[[ModelRequest], ModelResponse],
+            ) -> ModelCallResult:
                 try:
                     return handler(request)
                 except Exception:
@@ -122,17 +154,28 @@ class TestRetryLogic:
         """Test middleware with maximum retry limit."""
 
         class AlwaysFailModel(GenericFakeChatModel):
-            def _generate(self, messages, **kwargs):
+            @override
+            def _generate(
+                self,
+                messages: list[BaseMessage],
+                stop: list[str] | None = None,
+                run_manager: CallbackManagerForLLMRun | None = None,
+                **kwargs: Any,
+            ) -> ChatResult:
                 msg = "Always fails"
                 raise ValueError(msg)
 
         class MaxRetriesMiddleware(AgentMiddleware):
-            def __init__(self, max_retries=3):
+            def __init__(self, max_retries: int = 3):
                 super().__init__()
                 self.max_retries = max_retries
                 self.attempts = []
 
-            def wrap_model_call(self, request, handler):
+            def wrap_model_call(
+                self,
+                request: ModelRequest,
+                handler: Callable[[ModelRequest], ModelResponse],
+            ) -> ModelCallResult:
                 last_exception = None
                 for attempt in range(self.max_retries):
                     self.attempts.append(attempt + 1)
@@ -161,16 +204,27 @@ class TestRetryLogic:
         class FailingModel(BaseChatModel):
             """Model that always fails."""
 
-            def _generate(self, messages, **kwargs):
+            @override
+            def _generate(
+                self,
+                messages: list[BaseMessage],
+                stop: list[str] | None = None,
+                run_manager: CallbackManagerForLLMRun | None = None,
+                **kwargs: Any,
+            ) -> ChatResult:
                 msg = "Model error"
                 raise ValueError(msg)
 
             @property
-            def _llm_type(self):
+            def _llm_type(self) -> str:
                 return "failing"
 
         class NoRetryMiddleware(AgentMiddleware):
-            def wrap_model_call(self, request, handler):
+            def wrap_model_call(
+                self,
+                request: ModelRequest,
+                handler: Callable[[ModelRequest], ModelResponse],
+            ) -> ModelCallResult:
                 return handler(request)
 
         agent = create_agent(model=FailingModel(), middleware=[NoRetryMiddleware()])
@@ -184,12 +238,19 @@ class TestRetryLogic:
         class AlwaysFailingModel(BaseChatModel):
             """Model that always fails."""
 
-            def _generate(self, messages, **kwargs):
+            @override
+            def _generate(
+                self,
+                messages: list[BaseMessage],
+                stop: list[str] | None = None,
+                run_manager: CallbackManagerForLLMRun | None = None,
+                **kwargs: Any,
+            ) -> ChatResult:
                 msg = "Always fails"
                 raise ValueError(msg)
 
             @property
-            def _llm_type(self):
+            def _llm_type(self) -> str:
                 return "always_failing"
 
         class LimitedRetryMiddleware(AgentMiddleware):
@@ -200,7 +261,11 @@ class TestRetryLogic:
                 self.max_retries = max_retries
                 self.attempt_count = 0
 
-            def wrap_model_call(self, request, handler):
+            def wrap_model_call(
+                self,
+                request: ModelRequest,
+                handler: Callable[[ModelRequest], ModelResponse],
+            ) -> ModelCallResult:
                 last_exception = None
                 for _attempt in range(self.max_retries):
                     self.attempt_count += 1
@@ -235,10 +300,15 @@ class TestResponseRewriting:
         """Test middleware that transforms response to uppercase."""
 
         class UppercaseMiddleware(AgentMiddleware):
-            def wrap_model_call(self, request, handler):
+            def wrap_model_call(
+                self,
+                request: ModelRequest,
+                handler: Callable[[ModelRequest], ModelResponse],
+            ) -> ModelCallResult:
                 result = handler(request)
                 # result is ModelResponse, extract AIMessage from it
                 ai_message = result.result[0]
+                assert isinstance(ai_message.content, str)
                 return AIMessage(content=ai_message.content.upper())
 
         model = GenericFakeChatModel(messages=iter([AIMessage(content="hello world")]))
@@ -256,7 +326,11 @@ class TestResponseRewriting:
                 super().__init__()
                 self.prefix = prefix
 
-            def wrap_model_call(self, request, handler):
+            def wrap_model_call(
+                self,
+                request: ModelRequest,
+                handler: Callable[[ModelRequest], ModelResponse],
+            ) -> ModelCallResult:
                 result = handler(request)
                 # result is ModelResponse, extract AIMessage from it
                 ai_message = result.result[0]
@@ -273,12 +347,17 @@ class TestResponseRewriting:
         """Test middleware applying multiple transformations."""
 
        class MultiTransformMiddleware(AgentMiddleware):
-            def wrap_model_call(self, request, handler):
+            def wrap_model_call(
+                self,
+                request: ModelRequest,
+                handler: Callable[[ModelRequest], ModelResponse],
+            ) -> ModelCallResult:
                 result = handler(request)
                 # result is ModelResponse, extract AIMessage from it
                 ai_message = result.result[0]
                 # First transformation: uppercase
+                assert isinstance(ai_message.content, str)
                 content = ai_message.content.upper()
                 # Second transformation: add prefix and suffix
                 content = f"[START] {content} [END]"
@@ -299,12 +378,23 @@ class TestErrorHandling:
         """Test middleware that converts errors to successful responses."""
 
         class AlwaysFailModel(GenericFakeChatModel):
-            def _generate(self, messages, **kwargs):
+            @override
+            def _generate(
+                self,
+                messages: list[BaseMessage],
+                stop: list[str] | None = None,
+                run_manager: CallbackManagerForLLMRun | None = None,
+                **kwargs: Any,
+            ) -> ChatResult:
                 msg = "Model error"
                 raise ValueError(msg)
 
         class ErrorToSuccessMiddleware(AgentMiddleware):
-            def wrap_model_call(self, request, handler):
+            def wrap_model_call(
+                self,
+                request: ModelRequest,
+                handler: Callable[[ModelRequest], ModelResponse],
+            ) -> ModelCallResult:
                 try:
                     return handler(request)
                 except Exception as e:
@@ -323,12 +413,23 @@ class TestErrorHandling:
         """Test middleware that only handles specific errors."""
 
         class SpecificErrorModel(GenericFakeChatModel):
-            def _generate(self, messages, **kwargs):
+            @override
+            def _generate(
+                self,
+                messages: list[BaseMessage],
+                stop: list[str] | None = None,
+                run_manager: CallbackManagerForLLMRun | None = None,
+                **kwargs: Any,
+            ) -> ChatResult:
                 msg = "Network error"
                 raise ConnectionError(msg)
 
         class SelectiveErrorMiddleware(AgentMiddleware):
-            def wrap_model_call(self, request, handler):
+            def wrap_model_call(
+                self,
+                request: ModelRequest,
+                handler: Callable[[ModelRequest], ModelResponse],
+            ) -> ModelCallResult:
                 try:
                     return handler(request)
                 except ConnectionError:
@@ -346,7 +447,11 @@ class TestErrorHandling:
         call_log = []
 
         class ErrorRecoveryMiddleware(AgentMiddleware):
-            def wrap_model_call(self, request, handler):
+            def wrap_model_call(
+                self,
+                request: ModelRequest,
+                handler: Callable[[ModelRequest], ModelResponse],
+            ) -> ModelCallResult:
                 try:
                     call_log.append("before-yield")
                     result = handler(request)
@@ -369,7 +474,14 @@ class TestErrorHandling:
         call_log.clear()
 
         class AlwaysFailModel(GenericFakeChatModel):
-            def _generate(self, messages, **kwargs):
+            @override
+            def _generate(
+                self,
+                messages: list[BaseMessage],
+                stop: list[str] | None = None,
+                run_manager: CallbackManagerForLLMRun | None = None,
+                **kwargs: Any,
+            ) -> ChatResult:
                 msg = "Model error"
                 raise ValueError(msg)
 
@@ -386,11 +498,15 @@ class TestShortCircuit:
     def test_cache_short_circuit(self) -> None:
         """Test middleware that short-circuits with cached response."""
-        cache = {}
+        cache: dict[str, ModelResponse] = {}
         model_calls = []
 
         class CachingMiddleware(AgentMiddleware):
-            def wrap_model_call(self, request, handler):
+            def wrap_model_call(
+                self,
+                request: ModelRequest,
+                handler: Callable[[ModelRequest], ModelResponse],
+            ) -> ModelCallResult:
                 # Simple cache key based on last message
                 cache_key = str(request.messages[-1].content) if request.messages else ""
 
@@ -403,7 +519,14 @@ class TestShortCircuit:
                 return result
 
         class TrackingModel(GenericFakeChatModel):
-            def _generate(self, messages, **kwargs):
+            @override
+            def _generate(
+                self,
+                messages: list[BaseMessage],
+                stop: list[str] | None = None,
+                run_manager: CallbackManagerForLLMRun | None = None,
+                **kwargs: Any,
+            ) -> ChatResult:
                 model_calls.append(len(messages))
                 return super()._generate(messages, **kwargs)
 
@@ -445,7 +568,11 @@ class TestRequestModification:
                 super().__init__()
                 self.system_prompt = system_prompt
 
-            def wrap_model_call(self, request, handler):
+            def wrap_model_call(
+                self,
+                request: ModelRequest,
+                handler: Callable[[ModelRequest], ModelResponse],
+            ) -> ModelCallResult:
                 # Modify request to add system prompt
                 modified_request = ModelRequest(
                     model=request.model,
@@ -482,7 +609,11 @@ class TestStateAndRuntime:
         state_values = []
 
         class StateAwareMiddleware(AgentMiddleware):
-            def wrap_model_call(self, request, handler):
+            def wrap_model_call(
+                self,
+                request: ModelRequest,
+                handler: Callable[[ModelRequest], ModelResponse],
+            ) -> ModelCallResult:
                 # Access state from request
                 state_values.append(
                     {
@@ -504,7 +635,11 @@ class TestStateAndRuntime:
         """Test middleware that tracks retry count in state."""
 
         class StateTrackingRetryMiddleware(AgentMiddleware):
-            def wrap_model_call(self, request, handler):
+            def wrap_model_call(
+                self,
+                request: ModelRequest,
+                handler: Callable[[ModelRequest], ModelResponse],
+            ) -> ModelCallResult:
                 max_retries = 2
                 for attempt in range(max_retries):
                     try:
@@ -517,7 +652,14 @@ class TestStateAndRuntime:
         call_count = {"value": 0}
 
         class FailOnceThenSucceed(GenericFakeChatModel):
-            def _generate(self, messages, **kwargs):
+            @override
+            def _generate(
+                self,
+                messages: list[BaseMessage],
+                stop: list[str] | None = None,
+                run_manager: CallbackManagerForLLMRun | None = None,
+                **kwargs: Any,
+            ) -> ChatResult:
                 call_count["value"] += 1
                 if call_count["value"] == 1:
                     msg = "First fails"
@@ -541,14 +683,22 @@ class TestMiddlewareComposition:
         execution_order = []
 
         class OuterMiddleware(AgentMiddleware):
-            def wrap_model_call(self, request, handler):
+            def wrap_model_call(
+                self,
+                request: ModelRequest,
+                handler: Callable[[ModelRequest], ModelResponse],
+            ) -> ModelCallResult:
                 execution_order.append("outer-before")
                 response = handler(request)
                 execution_order.append("outer-after")
                 return response
 
         class InnerMiddleware(AgentMiddleware):
-            def wrap_model_call(self, request, handler):
+            def wrap_model_call(
+                self,
+                request: ModelRequest,
+                handler: Callable[[ModelRequest], ModelResponse],
+            ) -> ModelCallResult:
                 execution_order.append("inner-before")
                 response = handler(request)
                 execution_order.append("inner-after")
@@ -572,21 +722,33 @@ class TestMiddlewareComposition:
         execution_order = []
 
         class FirstMiddleware(AgentMiddleware):
-            def wrap_model_call(self, request, handler):
+            def wrap_model_call(
+                self,
+                request: ModelRequest,
+                handler: Callable[[ModelRequest], ModelResponse],
+            ) -> ModelCallResult:
                 execution_order.append("first-before")
                 response = handler(request)
                 execution_order.append("first-after")
                 return response
 
         class SecondMiddleware(AgentMiddleware):
-            def wrap_model_call(self, request, handler):
+            def wrap_model_call(
+                self,
+                request: ModelRequest,
+                handler: Callable[[ModelRequest], ModelResponse],
+            ) -> ModelCallResult:
                 execution_order.append("second-before")
                 response = handler(request)
                 execution_order.append("second-after")
                 return response
 
         class ThirdMiddleware(AgentMiddleware):
-            def wrap_model_call(self, request, handler):
+            def wrap_model_call(
+                self,
+                request: ModelRequest,
+                handler: Callable[[ModelRequest], ModelResponse],
+            ) -> ModelCallResult:
                 execution_order.append("third-before")
                 response = handler(request)
                 execution_order.append("third-after")
@@ -617,7 +779,14 @@ class TestMiddlewareComposition:
         log = []
 
         class FailOnceThenSucceed(GenericFakeChatModel):
-            def _generate(self, messages, **kwargs):
+            @override
+            def _generate(
+                self,
+                messages: list[BaseMessage],
+                stop: list[str] | None = None,
+                run_manager: CallbackManagerForLLMRun | None = None,
+                **kwargs: Any,
+            ) -> ChatResult:
                 call_count["value"] += 1
                 if call_count["value"] == 1:
                     msg = "First call fails"
@@ -625,14 +794,22 @@ class TestMiddlewareComposition:
                 return super()._generate(messages, **kwargs)
 
         class LoggingMiddleware(AgentMiddleware):
-            def wrap_model_call(self, request, handler):
+            def wrap_model_call(
+                self,
+                request: ModelRequest,
+                handler: Callable[[ModelRequest], ModelResponse],
+            ) -> ModelCallResult:
                 log.append("logging-before")
                 result = handler(request)
                 log.append("logging-after")
                 return result
 
         class RetryMiddleware(AgentMiddleware):
-            def wrap_model_call(self, request, handler):
+            def wrap_model_call(
+                self,
+                request: ModelRequest,
+                handler: Callable[[ModelRequest], ModelResponse],
+            ) -> ModelCallResult:
                 log.append("retry-before")
                 try:
                     result = handler(request)
@@ -664,14 +841,22 @@ class TestMiddlewareComposition:
         """Test multiple middleware that each transform the response."""
 
         class PrefixMiddleware(AgentMiddleware):
-            def wrap_model_call(self, request, handler):
+            def wrap_model_call(
+                self,
+                request: ModelRequest,
+                handler: Callable[[ModelRequest], ModelResponse],
+            ) -> ModelCallResult:
                 result = handler(request)
                 # result is ModelResponse, extract AIMessage from it
                 ai_message = result.result[0]
                 return AIMessage(content=f"[PREFIX] {ai_message.content}")
 
         class SuffixMiddleware(AgentMiddleware):
-            def wrap_model_call(self, request, handler):
+            def wrap_model_call(
+                self,
+                request: ModelRequest,
+                handler: Callable[[ModelRequest], ModelResponse],
+            ) -> ModelCallResult:
                 result = handler(request)
                 # result is ModelResponse, extract AIMessage from it
                 ai_message = result.result[0]
@@ -692,7 +877,14 @@ class TestMiddlewareComposition:
         call_count = {"value": 0}
 
         class FailOnceThenSucceed(GenericFakeChatModel):
-            def _generate(self, messages, **kwargs):
+            @override
+            def _generate(
+                self,
+                messages: list[BaseMessage],
+                stop: list[str] | None = None,
+                run_manager: CallbackManagerForLLMRun | None = None,
+                **kwargs: Any,
+            ) -> ChatResult:
                 call_count["value"] += 1
                 if call_count["value"] == 1:
                     msg = "First call fails"
@@ -700,17 +892,26 @@ class TestMiddlewareComposition:
                 return super()._generate(messages, **kwargs)
 
         class RetryMiddleware(AgentMiddleware):
-            def wrap_model_call(self, request, handler):
+            def wrap_model_call(
+                self,
+                request: ModelRequest,
+                handler: Callable[[ModelRequest], ModelResponse],
+            ) -> ModelCallResult:
                 try:
                     return handler(request)
                 except Exception:
                     return handler(request)
 
         class UppercaseMiddleware(AgentMiddleware):
-            def wrap_model_call(self, request, handler):
+            def wrap_model_call(
+                self,
+                request: ModelRequest,
+                handler: Callable[[ModelRequest], ModelResponse],
+            ) -> ModelCallResult:
                 result = handler(request)
                 # result is ModelResponse, extract AIMessage from it
                 ai_message = result.result[0]
+                assert isinstance(ai_message.content, str)
                 return AIMessage(content=ai_message.content.upper())
 
         model = FailOnceThenSucceed(messages=iter([AIMessage(content="success")]))
@@ -728,14 +929,22 @@ class TestMiddlewareComposition:
         model_calls = []
 
         class OuterMiddleware(AgentMiddleware):
-            def wrap_model_call(self, request, handler):
+            def wrap_model_call(
+                self,
+                request: ModelRequest,
+                handler: Callable[[ModelRequest], ModelResponse],
+            ) -> ModelCallResult:
                 execution_order.append("outer-before")
                 result = handler(request)
                 execution_order.append("outer-after")
                 return result
 
         class MiddleRetryMiddleware(AgentMiddleware):
-            def wrap_model_call(self, request, handler):
+            def wrap_model_call(
+                self,
+                request: ModelRequest,
+                handler: Callable[[ModelRequest], ModelResponse],
+            ) -> ModelCallResult:
                 execution_order.append("middle-before")
                 # Always retry once (call handler twice)
                 result = handler(request)
@@ -745,14 +954,25 @@ class TestMiddlewareComposition:
                 return result
 
         class InnerMiddleware(AgentMiddleware):
-            def wrap_model_call(self, request, handler):
+            def wrap_model_call(
+                self,
+                request: ModelRequest,
+                handler: Callable[[ModelRequest], ModelResponse],
+            ) -> ModelCallResult:
                 execution_order.append("inner-before")
                 result = handler(request)
                 execution_order.append("inner-after")
                 return result
 
         class TrackingModel(GenericFakeChatModel):
-            def _generate(self, messages, **kwargs):
+            @override
+            def _generate(
+                self,
+                messages: list[BaseMessage],
+                stop: list[str] | None = None,
+                run_manager: CallbackManagerForLLMRun | None = None,
+                **kwargs: Any,
+            ) -> ChatResult:
                 model_calls.append(len(messages))
                 return super()._generate(messages, **kwargs)
 
@@ -789,7 +1009,10 @@ class TestWrapModelCallDecorator:
         """Test basic decorator usage without parameters."""
 
         @wrap_model_call
-        def passthrough_middleware(request, handler):
+        def passthrough_middleware(
+            request: ModelRequest,
+            handler: Callable[[ModelRequest], ModelResponse],
+        ) -> ModelCallResult:
             return handler(request)
 
         # Should return an AgentMiddleware instance
@@ -807,7 +1030,10 @@ class TestWrapModelCallDecorator:
         """Test decorator with custom middleware name."""
 
         @wrap_model_call(name="CustomMiddleware")
-        def my_middleware(request, handler):
+        def my_middleware(
+            request: ModelRequest,
+            handler: Callable[[ModelRequest], ModelResponse],
+        ) -> ModelCallResult:
             return handler(request)
 
         assert isinstance(my_middleware, AgentMiddleware)
@@ -818,7 +1044,14 @@ class TestWrapModelCallDecorator:
         call_count = {"value": 0}
 
         class FailOnceThenSucceed(GenericFakeChatModel):
-            def _generate(self, messages, **kwargs):
+            @override
+            def _generate(
+                self,
+                messages: list[BaseMessage],
+                stop: list[str] | None = None,
+                run_manager: CallbackManagerForLLMRun | None = None,
+                **kwargs: Any,
+            ) -> ChatResult:
                 call_count["value"] += 1
                 if call_count["value"] == 1:
                     msg = "First call fails"
@@ -826,7 +1059,10 @@ class TestWrapModelCallDecorator:
                 return super()._generate(messages, **kwargs)
 
         @wrap_model_call
-        def retry_once(request, handler):
+        def retry_once(
+            request: ModelRequest,
+            handler: Callable[[ModelRequest], ModelResponse],
+        ) -> ModelCallResult:
             try:
                 return handler(request)
             except Exception:
@@ -845,10 +1081,14 @@ class TestWrapModelCallDecorator:
         """Test decorator for rewriting responses."""
 
         @wrap_model_call
-        def uppercase_responses(request, handler):
+        def uppercase_responses(
+            request: ModelRequest,
+            handler: Callable[[ModelRequest], ModelResponse],
+        ) -> ModelCallResult:
             result = handler(request)
             # result is ModelResponse, extract AIMessage from it
             ai_message = result.result[0]
+            assert isinstance(ai_message.content, str)
             return AIMessage(content=ai_message.content.upper())
 
         model = GenericFakeChatModel(messages=iter([AIMessage(content="hello world")]))
@@ -862,12 +1102,22 @@ class TestWrapModelCallDecorator:
         """Test decorator for error recovery."""
 
         class AlwaysFailModel(GenericFakeChatModel):
-            def _generate(self, messages, **kwargs):
+            @override
+            def _generate(
+                self,
+                messages: list[BaseMessage],
+                stop: list[str] | None = None,
+                run_manager: CallbackManagerForLLMRun | None = None,
+                **kwargs: Any,
+            ) -> ChatResult:
                 msg = "Model error"
                 raise ValueError(msg)
 
         @wrap_model_call
-        def error_to_fallback(request, handler):
+        def error_to_fallback(
+            request: ModelRequest,
+            handler: Callable[[ModelRequest], ModelResponse],
+        ) -> ModelCallResult:
             try:
                 return handler(request)
             except Exception:
@@ -885,7 +1135,10 @@ class TestWrapModelCallDecorator:
         state_values = []
 
         @wrap_model_call
-        def log_state(request, handler):
+        def log_state(
+            request: ModelRequest,
+            handler: Callable[[ModelRequest], ModelResponse],
+        ) -> ModelCallResult:
             state_values.append(request.state.get("messages"))
             return handler(request)
 
@@ -904,14 +1157,20 @@ class TestWrapModelCallDecorator:
         execution_order = []
 
         @wrap_model_call
-        def outer_middleware(request, handler):
+        def outer_middleware(
+            request: ModelRequest,
+            handler: Callable[[ModelRequest], ModelResponse],
+        ) -> ModelCallResult:
            execution_order.append("outer-before")
            result = handler(request)
            execution_order.append("outer-after")
            return result
 
         @wrap_model_call
-        def inner_middleware(request, handler):
+        def inner_middleware(
+            request: ModelRequest,
+            handler: Callable[[ModelRequest], ModelResponse],
+        ) -> ModelCallResult:
            execution_order.append("inner-before")
            result = handler(request)
            execution_order.append("inner-after")
@@ -933,11 +1192,14 @@ class TestWrapModelCallDecorator:
         """Test decorator with custom state schema."""
 
         class CustomState(TypedDict):
-            messages: list
+            messages: list[Any]
             custom_field: str
 
         @wrap_model_call(state_schema=CustomState)
-        def middleware_with_schema(request, handler):
+        def middleware_with_schema(
+            request: ModelRequest,
+            handler: Callable[[ModelRequest], ModelResponse],
+        ) -> ModelCallResult:
             return handler(request)
 
         assert isinstance(middleware_with_schema, AgentMiddleware)
@@ -953,7 +1215,10 @@ class TestWrapModelCallDecorator:
             return f"Result: {query}"
 
         @wrap_model_call(tools=[test_tool])
-        def middleware_with_tools(request, handler):
+        def middleware_with_tools(
+            request: ModelRequest,
+            handler: Callable[[ModelRequest], ModelResponse],
+        ) -> ModelCallResult:
             return handler(request)
 
         assert isinstance(middleware_with_tools, AgentMiddleware)
@@ -965,12 +1230,18 @@ class TestWrapModelCallDecorator:
 
         # Without parentheses
         @wrap_model_call
-        def middleware_no_parens(request, handler):
+        def middleware_no_parens(
+            request: ModelRequest,
+            handler: Callable[[ModelRequest], ModelResponse],
+        ) -> ModelCallResult:
             return handler(request)
 
         # With parentheses
         @wrap_model_call()
-        def middleware_with_parens(request, handler):
+        def middleware_with_parens(
+            request: ModelRequest,
+            handler: Callable[[ModelRequest], ModelResponse],
+        ) -> ModelCallResult:
             return handler(request)
 
         assert isinstance(middleware_no_parens, AgentMiddleware)
@@ -980,7 +1251,10 @@ class TestWrapModelCallDecorator:
         """Test that decorator uses function name for class name."""
 
         @wrap_model_call
-        def my_custom_middleware(request, handler):
+        def my_custom_middleware(
+            request: ModelRequest,
+            handler: Callable[[ModelRequest], ModelResponse],
+        ) -> ModelCallResult:
             return handler(request)
 
         assert my_custom_middleware.__class__.__name__ == "my_custom_middleware"
@@ -990,14 +1264,21 @@ class TestWrapModelCallDecorator:
         execution_order = []
 
         @wrap_model_call
-        def decorated_middleware(request, handler):
+        def decorated_middleware(
+            request: ModelRequest,
+            handler: Callable[[ModelRequest], ModelResponse],
+        ) -> ModelCallResult:
             execution_order.append("decorated-before")
             result = handler(request)
             execution_order.append("decorated-after")
             return result
 
         class ClassMiddleware(AgentMiddleware):
-            def wrap_model_call(self, request, handler):
+            def wrap_model_call(
+                self,
+                request: ModelRequest,
+                handler: Callable[[ModelRequest], ModelResponse],
+            ) -> ModelCallResult:
                 execution_order.append("class-before")
                 result = handler(request)
                 execution_order.append("class-after")
@@ -1025,7 +1306,14 @@ class TestWrapModelCallDecorator:
         call_count = {"value": 0}
 
         class UnreliableModel(GenericFakeChatModel):
-            def _generate(self, messages, **kwargs):
+            @override
+            def _generate(
+                self,
+                messages: list[BaseMessage],
+                stop: list[str] | None = None,
+                run_manager: CallbackManagerForLLMRun | None = None,
+                **kwargs: Any,
+            ) -> ChatResult:
                 call_count["value"] += 1
                 if call_count["value"] <= 2:
                     msg = f"Attempt {call_count['value']} failed"
@@ -1033,7 +1321,10 @@ class TestWrapModelCallDecorator:
                 return super()._generate(messages, **kwargs)
 
         @wrap_model_call
-        def retry_with_tracking(request, handler):
+        def retry_with_tracking(
+            request: ModelRequest,
+            handler: Callable[[ModelRequest], ModelResponse],
+        ) -> ModelCallResult:
             max_retries = 3
             for attempt in range(max_retries):
                 attempts.append(attempt + 1)
@@ -1059,7 +1350,10 @@ class TestWrapModelCallDecorator:
         modified_prompts = []
 
         @wrap_model_call
-        def add_system_prompt(request, handler):
+        def add_system_prompt(
+            request: ModelRequest,
+            handler: Callable[[ModelRequest], ModelResponse],
+        ) -> ModelCallResult:
             # Modify request to add system prompt
             modified_request = ModelRequest(
                 messages=request.messages,
@@ -1068,7 +1362,7 @@ class TestWrapModelCallDecorator:
                 tool_choice=request.tool_choice,
                 tools=request.tools,
                 response_format=request.response_format,
-                state={},
+                state=AgentState[Any](messages=[]),
                 runtime=None,
             )
             modified_prompts.append(modified_request.system_prompt)
@@ -1090,7 +1384,11 @@ class TestAsyncWrapModelCall:
         log = []
 
         class LoggingMiddleware(AgentMiddleware):
-            async def awrap_model_call(self, request, handler):
+            async def awrap_model_call(
+                self,
+                request: ModelRequest,
+                handler: Callable[[ModelRequest], Awaitable[ModelResponse]],
+            ) -> ModelCallResult:
                 log.append("before")
                 result = await handler(request)
                 log.append("after")
                 return result
@@ -1109,7 +1407,14 @@ class TestAsyncWrapModelCall:
         call_count = {"value": 0}
 
         class AsyncFailOnceThenSucceed(GenericFakeChatModel):
-            async def _agenerate(self, messages, **kwargs):
+            @override
+            async def _agenerate(
+                self,
+                messages: list[BaseMessage],
+                stop: list[str] | None = None,
+                run_manager: AsyncCallbackManagerForLLMRun | None = None,
+                **kwargs: Any,
+            ) -> ChatResult:
                 call_count["value"] += 1
                 if call_count["value"] == 1:
                     msg = "First async call fails"
@@ -1117,7 +1422,11 @@ class TestAsyncWrapModelCall:
                 return await super()._agenerate(messages, **kwargs)
 
         class RetryMiddleware(AgentMiddleware):
-            async def awrap_model_call(self, request, handler):
+            async def awrap_model_call(
+                self,
+                request: ModelRequest,
+                handler: Callable[[ModelRequest], Awaitable[ModelResponse]],
+            ) -> ModelCallResult:
                 try:
                     return await handler(request)
                 except Exception:
@@ -1136,7 +1445,10 @@ class TestAsyncWrapModelCall:
         call_log = []
 
         @wrap_model_call
-        async def logging_middleware(request, handler):
+        async def logging_middleware(
+            request: ModelRequest,
+            handler: Callable[[ModelRequest], Awaitable[ModelResponse]],
+        ) -> ModelCallResult:
             call_log.append("before")
             result = await handler(request)
             call_log.append("after")
             return result
@@ -1161,8 +1473,8 @@ class TestSyncAsyncInterop:
             async def awrap_model_call(
                 self,
                 request: ModelRequest,
-                handler: Callable[[ModelRequest], Awaitable[AIMessage]],
-            ) -> AIMessage:
+                handler: Callable[[ModelRequest], Awaitable[ModelResponse]],
+            ) -> ModelCallResult:
                 return await handler(request)
 
         agent = create_agent(
@@ -1180,25 +1492,27 @@ class TestSyncAsyncInterop:
         calls = []
 
         class MixedMiddleware(AgentMiddleware):
-            def before_model(self, state, runtime) -> None:
+            @override
+            def before_model(self, state: AgentState[Any], runtime: Runtime[Any]) -> None:
                 calls.append("MixedMiddleware.before_model")
 
-            async def abefore_model(self, state, runtime) -> None:
+            @override
+            async def abefore_model(self, state: AgentState[Any], runtime: Runtime[Any]) -> None:
                 calls.append("MixedMiddleware.abefore_model")
 
             def wrap_model_call(
                 self,
                 request: ModelRequest,
-                handler: Callable[[ModelRequest], AIMessage],
-            ) -> AIMessage:
+                handler: Callable[[ModelRequest], ModelResponse],
+            ) -> ModelCallResult:
                 calls.append("MixedMiddleware.wrap_model_call")
                 return handler(request)
 
             async def awrap_model_call(
                 self,
                 request: ModelRequest,
-                handler: Callable[[ModelRequest], Awaitable[AIMessage]],
-            ) -> AIMessage:
+                handler: Callable[[ModelRequest], Awaitable[ModelResponse]],
+            ) -> ModelCallResult:
                 calls.append("MixedMiddleware.awrap_model_call")
                 return await handler(request)
 
@@ -1226,7 +1540,11 @@ class TestEdgeCases:
         modified_messages = []
 
         class RequestModifyingMiddleware(AgentMiddleware):
-            def wrap_model_call(self, request, handler):
+            def wrap_model_call(
+                self,
+                request: ModelRequest,
+                handler: Callable[[ModelRequest], ModelResponse],
+            ) -> ModelCallResult:
                 # Add a system message to the request
                 modified_request = request
                 modified_messages.append(len(modified_request.messages))
@@ -1244,7 +1562,11 @@ class TestEdgeCases:
         attempts = []
 
         class MultiModelRetryMiddleware(AgentMiddleware):
-            def wrap_model_call(self, request, handler):
+            def wrap_model_call(
+                self,
+                request: ModelRequest,
+                handler: Callable[[ModelRequest], ModelResponse],
+            ) -> ModelCallResult:
                 attempts.append("first-attempt")
                 try:
                     return handler(request)
@@ -1255,7 +1577,14 @@ class TestEdgeCases:
         call_count = {"value": 0}
 
         class FailFirstSucceedSecond(GenericFakeChatModel):
-            def _generate(self, messages, **kwargs):
+            @override
+            def _generate(
+                self,
+                messages: list[BaseMessage],
+                stop: list[str] | None = None,
+                run_manager: CallbackManagerForLLMRun | None = None,
+                **kwargs: Any,
+            ) -> ChatResult:
                 call_count["value"] += 1
                 if call_count["value"] == 1:
                     msg = "First fails"