chore(langchain): fix types in test_wrap_model_call (#34573)

This commit is contained in:
Christophe Bornet
2026-01-07 17:49:46 +01:00
committed by GitHub
parent 0c7b7e045d
commit f10225184d

View File

@@ -7,18 +7,27 @@ This module tests the wrap_model_call functionality in three forms:
"""
from collections.abc import Awaitable, Callable
from typing import Any
import pytest
from langchain_core.callbacks import (
AsyncCallbackManagerForLLMRun,
CallbackManagerForLLMRun,
)
from langchain_core.language_models.chat_models import BaseChatModel
from langchain_core.language_models.fake_chat_models import GenericFakeChatModel
from langchain_core.messages import AIMessage, HumanMessage
from langchain_core.messages import AIMessage, BaseMessage, HumanMessage
from langchain_core.outputs import ChatResult
from langchain_core.tools import tool
from typing_extensions import TypedDict
from langgraph.runtime import Runtime
from typing_extensions import TypedDict, override
from langchain.agents import create_agent
from langchain.agents import AgentState, create_agent
from langchain.agents.middleware.types import (
AgentMiddleware,
ModelCallResult,
ModelRequest,
ModelResponse,
wrap_model_call,
)
from tests.unit_tests.agents.model import FakeToolCallingModel
@@ -31,7 +40,11 @@ class TestBasicWrapModelCall:
"""Test middleware that simply passes through without modification."""
class PassthroughMiddleware(AgentMiddleware):
def wrap_model_call(self, request, handler):
def wrap_model_call(
self,
request: ModelRequest,
handler: Callable[[ModelRequest], ModelResponse],
) -> ModelCallResult:
return handler(request)
model = GenericFakeChatModel(messages=iter([AIMessage(content="Hello")]))
@@ -47,7 +60,11 @@ class TestBasicWrapModelCall:
call_log = []
class LoggingMiddleware(AgentMiddleware):
def wrap_model_call(self, request, handler):
def wrap_model_call(
self,
request: ModelRequest,
handler: Callable[[ModelRequest], ModelResponse],
) -> ModelCallResult:
call_log.append("before")
result = handler(request)
call_log.append("after")
@@ -65,11 +82,15 @@ class TestBasicWrapModelCall:
"""Test middleware that counts model calls."""
class CountingMiddleware(AgentMiddleware):
def __init__(self):
def __init__(self) -> None:
super().__init__()
self.call_count = 0
def wrap_model_call(self, request, handler):
def wrap_model_call(
self,
request: ModelRequest,
handler: Callable[[ModelRequest], ModelResponse],
) -> ModelCallResult:
self.call_count += 1
return handler(request)
@@ -90,7 +111,14 @@ class TestRetryLogic:
call_count = {"value": 0}
class FailOnceThenSucceed(GenericFakeChatModel):
def _generate(self, messages, **kwargs):
@override
def _generate(
self,
messages: list[BaseMessage],
stop: list[str] | None = None,
run_manager: CallbackManagerForLLMRun | None = None,
**kwargs: Any,
) -> ChatResult:
call_count["value"] += 1
if call_count["value"] == 1:
msg = "First call fails"
@@ -98,11 +126,15 @@ class TestRetryLogic:
return super()._generate(messages, **kwargs)
class RetryOnceMiddleware(AgentMiddleware):
def __init__(self):
def __init__(self) -> None:
super().__init__()
self.retry_count = 0
def wrap_model_call(self, request, handler):
def wrap_model_call(
self,
request: ModelRequest,
handler: Callable[[ModelRequest], ModelResponse],
) -> ModelCallResult:
try:
return handler(request)
except Exception:
@@ -122,17 +154,28 @@ class TestRetryLogic:
"""Test middleware with maximum retry limit."""
class AlwaysFailModel(GenericFakeChatModel):
def _generate(self, messages, **kwargs):
@override
def _generate(
self,
messages: list[BaseMessage],
stop: list[str] | None = None,
run_manager: CallbackManagerForLLMRun | None = None,
**kwargs: Any,
) -> ChatResult:
msg = "Always fails"
raise ValueError(msg)
class MaxRetriesMiddleware(AgentMiddleware):
def __init__(self, max_retries=3):
def __init__(self, max_retries: int = 3) -> None:
    """Set up the retry limit and an empty log of attempt numbers.

    Args:
        max_retries: Upper bound on how many times `wrap_model_call` loops
            over the handler.
    """
    super().__init__()
    self.max_retries = max_retries
    # One entry per attempt, 1-based (wrap_model_call appends `attempt + 1`).
    self.attempts: list[int] = []
def wrap_model_call(self, request, handler):
def wrap_model_call(
self,
request: ModelRequest,
handler: Callable[[ModelRequest], ModelResponse],
) -> ModelCallResult:
last_exception = None
for attempt in range(self.max_retries):
self.attempts.append(attempt + 1)
@@ -161,16 +204,27 @@ class TestRetryLogic:
class FailingModel(BaseChatModel):
"""Model that always fails."""
def _generate(self, messages, **kwargs):
@override
def _generate(
self,
messages: list[BaseMessage],
stop: list[str] | None = None,
run_manager: CallbackManagerForLLMRun | None = None,
**kwargs: Any,
) -> ChatResult:
msg = "Model error"
raise ValueError(msg)
@property
def _llm_type(self):
def _llm_type(self) -> str:
return "failing"
class NoRetryMiddleware(AgentMiddleware):
def wrap_model_call(self, request, handler):
def wrap_model_call(
self,
request: ModelRequest,
handler: Callable[[ModelRequest], ModelResponse],
) -> ModelCallResult:
return handler(request)
agent = create_agent(model=FailingModel(), middleware=[NoRetryMiddleware()])
@@ -184,12 +238,19 @@ class TestRetryLogic:
class AlwaysFailingModel(BaseChatModel):
"""Model that always fails."""
def _generate(self, messages, **kwargs):
@override
def _generate(
self,
messages: list[BaseMessage],
stop: list[str] | None = None,
run_manager: CallbackManagerForLLMRun | None = None,
**kwargs: Any,
) -> ChatResult:
msg = "Always fails"
raise ValueError(msg)
@property
def _llm_type(self):
def _llm_type(self) -> str:
return "always_failing"
class LimitedRetryMiddleware(AgentMiddleware):
@@ -200,7 +261,11 @@ class TestRetryLogic:
self.max_retries = max_retries
self.attempt_count = 0
def wrap_model_call(self, request, handler):
def wrap_model_call(
self,
request: ModelRequest,
handler: Callable[[ModelRequest], ModelResponse],
) -> ModelCallResult:
last_exception = None
for _attempt in range(self.max_retries):
self.attempt_count += 1
@@ -235,10 +300,15 @@ class TestResponseRewriting:
"""Test middleware that transforms response to uppercase."""
class UppercaseMiddleware(AgentMiddleware):
def wrap_model_call(self, request, handler):
def wrap_model_call(
self,
request: ModelRequest,
handler: Callable[[ModelRequest], ModelResponse],
) -> ModelCallResult:
result = handler(request)
# result is ModelResponse, extract AIMessage from it
ai_message = result.result[0]
assert isinstance(ai_message.content, str)
return AIMessage(content=ai_message.content.upper())
model = GenericFakeChatModel(messages=iter([AIMessage(content="hello world")]))
@@ -256,7 +326,11 @@ class TestResponseRewriting:
super().__init__()
self.prefix = prefix
def wrap_model_call(self, request, handler):
def wrap_model_call(
self,
request: ModelRequest,
handler: Callable[[ModelRequest], ModelResponse],
) -> ModelCallResult:
result = handler(request)
# result is ModelResponse, extract AIMessage from it
ai_message = result.result[0]
@@ -273,12 +347,17 @@ class TestResponseRewriting:
"""Test middleware applying multiple transformations."""
class MultiTransformMiddleware(AgentMiddleware):
def wrap_model_call(self, request, handler):
def wrap_model_call(
self,
request: ModelRequest,
handler: Callable[[ModelRequest], ModelResponse],
) -> ModelCallResult:
result = handler(request)
# result is ModelResponse, extract AIMessage from it
ai_message = result.result[0]
# First transformation: uppercase
assert isinstance(ai_message.content, str)
content = ai_message.content.upper()
# Second transformation: add prefix and suffix
content = f"[START] {content} [END]"
@@ -299,12 +378,23 @@ class TestErrorHandling:
"""Test middleware that converts errors to successful responses."""
class AlwaysFailModel(GenericFakeChatModel):
def _generate(self, messages, **kwargs):
@override
def _generate(
self,
messages: list[BaseMessage],
stop: list[str] | None = None,
run_manager: CallbackManagerForLLMRun | None = None,
**kwargs: Any,
) -> ChatResult:
msg = "Model error"
raise ValueError(msg)
class ErrorToSuccessMiddleware(AgentMiddleware):
def wrap_model_call(self, request, handler):
def wrap_model_call(
self,
request: ModelRequest,
handler: Callable[[ModelRequest], ModelResponse],
) -> ModelCallResult:
try:
return handler(request)
except Exception as e:
@@ -323,12 +413,23 @@ class TestErrorHandling:
"""Test middleware that only handles specific errors."""
class SpecificErrorModel(GenericFakeChatModel):
def _generate(self, messages, **kwargs):
@override
def _generate(
self,
messages: list[BaseMessage],
stop: list[str] | None = None,
run_manager: CallbackManagerForLLMRun | None = None,
**kwargs: Any,
) -> ChatResult:
msg = "Network error"
raise ConnectionError(msg)
class SelectiveErrorMiddleware(AgentMiddleware):
def wrap_model_call(self, request, handler):
def wrap_model_call(
self,
request: ModelRequest,
handler: Callable[[ModelRequest], ModelResponse],
) -> ModelCallResult:
try:
return handler(request)
except ConnectionError:
@@ -346,7 +447,11 @@ class TestErrorHandling:
call_log = []
class ErrorRecoveryMiddleware(AgentMiddleware):
def wrap_model_call(self, request, handler):
def wrap_model_call(
self,
request: ModelRequest,
handler: Callable[[ModelRequest], ModelResponse],
) -> ModelCallResult:
try:
call_log.append("before-yield")
result = handler(request)
@@ -369,7 +474,14 @@ class TestErrorHandling:
call_log.clear()
class AlwaysFailModel(GenericFakeChatModel):
def _generate(self, messages, **kwargs):
@override
def _generate(
self,
messages: list[BaseMessage],
stop: list[str] | None = None,
run_manager: CallbackManagerForLLMRun | None = None,
**kwargs: Any,
) -> ChatResult:
msg = "Model error"
raise ValueError(msg)
@@ -386,11 +498,15 @@ class TestShortCircuit:
def test_cache_short_circuit(self) -> None:
"""Test middleware that short-circuits with cached response."""
cache = {}
cache: dict[str, ModelResponse] = {}
model_calls = []
class CachingMiddleware(AgentMiddleware):
def wrap_model_call(self, request, handler):
def wrap_model_call(
self,
request: ModelRequest,
handler: Callable[[ModelRequest], ModelResponse],
) -> ModelCallResult:
# Simple cache key based on last message
cache_key = str(request.messages[-1].content) if request.messages else ""
@@ -403,7 +519,14 @@ class TestShortCircuit:
return result
class TrackingModel(GenericFakeChatModel):
def _generate(self, messages, **kwargs):
@override
def _generate(
    self,
    messages: list[BaseMessage],
    stop: list[str] | None = None,
    run_manager: CallbackManagerForLLMRun | None = None,
    **kwargs: Any,
) -> ChatResult:
    """Record the size of the incoming message list, then delegate to the parent.

    Args:
        messages: Messages passed to the model for this call.
        stop: Optional stop sequences, forwarded unchanged to the parent.
        run_manager: Optional callback manager, forwarded unchanged to the parent.
        **kwargs: Any additional keyword arguments, forwarded unchanged.

    Returns:
        Whatever the parent `GenericFakeChatModel._generate` returns.
    """
    model_calls.append(len(messages))
    # `stop` and `run_manager` are named parameters now, so they are no longer
    # captured by **kwargs and have to be forwarded explicitly.
    return super()._generate(messages, stop=stop, run_manager=run_manager, **kwargs)
@@ -445,7 +568,11 @@ class TestRequestModification:
super().__init__()
self.system_prompt = system_prompt
def wrap_model_call(self, request, handler):
def wrap_model_call(
self,
request: ModelRequest,
handler: Callable[[ModelRequest], ModelResponse],
) -> ModelCallResult:
# Modify request to add system prompt
modified_request = ModelRequest(
model=request.model,
@@ -482,7 +609,11 @@ class TestStateAndRuntime:
state_values = []
class StateAwareMiddleware(AgentMiddleware):
def wrap_model_call(self, request, handler):
def wrap_model_call(
self,
request: ModelRequest,
handler: Callable[[ModelRequest], ModelResponse],
) -> ModelCallResult:
# Access state from request
state_values.append(
{
@@ -504,7 +635,11 @@ class TestStateAndRuntime:
"""Test middleware that tracks retry count in state."""
class StateTrackingRetryMiddleware(AgentMiddleware):
def wrap_model_call(self, request, handler):
def wrap_model_call(
self,
request: ModelRequest,
handler: Callable[[ModelRequest], ModelResponse],
) -> ModelCallResult:
max_retries = 2
for attempt in range(max_retries):
try:
@@ -517,7 +652,14 @@ class TestStateAndRuntime:
call_count = {"value": 0}
class FailOnceThenSucceed(GenericFakeChatModel):
def _generate(self, messages, **kwargs):
@override
def _generate(
self,
messages: list[BaseMessage],
stop: list[str] | None = None,
run_manager: CallbackManagerForLLMRun | None = None,
**kwargs: Any,
) -> ChatResult:
call_count["value"] += 1
if call_count["value"] == 1:
msg = "First fails"
@@ -541,14 +683,22 @@ class TestMiddlewareComposition:
execution_order = []
class OuterMiddleware(AgentMiddleware):
def wrap_model_call(self, request, handler):
def wrap_model_call(
self,
request: ModelRequest,
handler: Callable[[ModelRequest], ModelResponse],
) -> ModelCallResult:
execution_order.append("outer-before")
response = handler(request)
execution_order.append("outer-after")
return response
class InnerMiddleware(AgentMiddleware):
def wrap_model_call(self, request, handler):
def wrap_model_call(
self,
request: ModelRequest,
handler: Callable[[ModelRequest], ModelResponse],
) -> ModelCallResult:
execution_order.append("inner-before")
response = handler(request)
execution_order.append("inner-after")
@@ -572,21 +722,33 @@ class TestMiddlewareComposition:
execution_order = []
class FirstMiddleware(AgentMiddleware):
def wrap_model_call(self, request, handler):
def wrap_model_call(
self,
request: ModelRequest,
handler: Callable[[ModelRequest], ModelResponse],
) -> ModelCallResult:
execution_order.append("first-before")
response = handler(request)
execution_order.append("first-after")
return response
class SecondMiddleware(AgentMiddleware):
def wrap_model_call(self, request, handler):
def wrap_model_call(
self,
request: ModelRequest,
handler: Callable[[ModelRequest], ModelResponse],
) -> ModelCallResult:
execution_order.append("second-before")
response = handler(request)
execution_order.append("second-after")
return response
class ThirdMiddleware(AgentMiddleware):
def wrap_model_call(self, request, handler):
def wrap_model_call(
self,
request: ModelRequest,
handler: Callable[[ModelRequest], ModelResponse],
) -> ModelCallResult:
execution_order.append("third-before")
response = handler(request)
execution_order.append("third-after")
@@ -617,7 +779,14 @@ class TestMiddlewareComposition:
log = []
class FailOnceThenSucceed(GenericFakeChatModel):
def _generate(self, messages, **kwargs):
@override
def _generate(
self,
messages: list[BaseMessage],
stop: list[str] | None = None,
run_manager: CallbackManagerForLLMRun | None = None,
**kwargs: Any,
) -> ChatResult:
call_count["value"] += 1
if call_count["value"] == 1:
msg = "First call fails"
@@ -625,14 +794,22 @@ class TestMiddlewareComposition:
return super()._generate(messages, **kwargs)
class LoggingMiddleware(AgentMiddleware):
def wrap_model_call(self, request, handler):
def wrap_model_call(
self,
request: ModelRequest,
handler: Callable[[ModelRequest], ModelResponse],
) -> ModelCallResult:
log.append("logging-before")
result = handler(request)
log.append("logging-after")
return result
class RetryMiddleware(AgentMiddleware):
def wrap_model_call(self, request, handler):
def wrap_model_call(
self,
request: ModelRequest,
handler: Callable[[ModelRequest], ModelResponse],
) -> ModelCallResult:
log.append("retry-before")
try:
result = handler(request)
@@ -664,14 +841,22 @@ class TestMiddlewareComposition:
"""Test multiple middleware that each transform the response."""
class PrefixMiddleware(AgentMiddleware):
def wrap_model_call(self, request, handler):
def wrap_model_call(
self,
request: ModelRequest,
handler: Callable[[ModelRequest], ModelResponse],
) -> ModelCallResult:
result = handler(request)
# result is ModelResponse, extract AIMessage from it
ai_message = result.result[0]
return AIMessage(content=f"[PREFIX] {ai_message.content}")
class SuffixMiddleware(AgentMiddleware):
def wrap_model_call(self, request, handler):
def wrap_model_call(
self,
request: ModelRequest,
handler: Callable[[ModelRequest], ModelResponse],
) -> ModelCallResult:
result = handler(request)
# result is ModelResponse, extract AIMessage from it
ai_message = result.result[0]
@@ -692,7 +877,14 @@ class TestMiddlewareComposition:
call_count = {"value": 0}
class FailOnceThenSucceed(GenericFakeChatModel):
def _generate(self, messages, **kwargs):
@override
def _generate(
self,
messages: list[BaseMessage],
stop: list[str] | None = None,
run_manager: CallbackManagerForLLMRun | None = None,
**kwargs: Any,
) -> ChatResult:
call_count["value"] += 1
if call_count["value"] == 1:
msg = "First call fails"
@@ -700,17 +892,26 @@ class TestMiddlewareComposition:
return super()._generate(messages, **kwargs)
class RetryMiddleware(AgentMiddleware):
def wrap_model_call(self, request, handler):
def wrap_model_call(
self,
request: ModelRequest,
handler: Callable[[ModelRequest], ModelResponse],
) -> ModelCallResult:
try:
return handler(request)
except Exception:
return handler(request)
class UppercaseMiddleware(AgentMiddleware):
def wrap_model_call(self, request, handler):
def wrap_model_call(
self,
request: ModelRequest,
handler: Callable[[ModelRequest], ModelResponse],
) -> ModelCallResult:
result = handler(request)
# result is ModelResponse, extract AIMessage from it
ai_message = result.result[0]
assert isinstance(ai_message.content, str)
return AIMessage(content=ai_message.content.upper())
model = FailOnceThenSucceed(messages=iter([AIMessage(content="success")]))
@@ -728,14 +929,22 @@ class TestMiddlewareComposition:
model_calls = []
class OuterMiddleware(AgentMiddleware):
def wrap_model_call(self, request, handler):
def wrap_model_call(
self,
request: ModelRequest,
handler: Callable[[ModelRequest], ModelResponse],
) -> ModelCallResult:
execution_order.append("outer-before")
result = handler(request)
execution_order.append("outer-after")
return result
class MiddleRetryMiddleware(AgentMiddleware):
def wrap_model_call(self, request, handler):
def wrap_model_call(
self,
request: ModelRequest,
handler: Callable[[ModelRequest], ModelResponse],
) -> ModelCallResult:
execution_order.append("middle-before")
# Always retry once (call handler twice)
result = handler(request)
@@ -745,14 +954,25 @@ class TestMiddlewareComposition:
return result
class InnerMiddleware(AgentMiddleware):
def wrap_model_call(self, request, handler):
def wrap_model_call(
self,
request: ModelRequest,
handler: Callable[[ModelRequest], ModelResponse],
) -> ModelCallResult:
execution_order.append("inner-before")
result = handler(request)
execution_order.append("inner-after")
return result
class TrackingModel(GenericFakeChatModel):
def _generate(self, messages, **kwargs):
@override
def _generate(
    self,
    messages: list[BaseMessage],
    stop: list[str] | None = None,
    run_manager: CallbackManagerForLLMRun | None = None,
    **kwargs: Any,
) -> ChatResult:
    """Log how many messages this call received, then defer to the parent model.

    Args:
        messages: Messages passed to the model for this call.
        stop: Optional stop sequences, forwarded unchanged to the parent.
        run_manager: Optional callback manager, forwarded unchanged to the parent.
        **kwargs: Any additional keyword arguments, forwarded unchanged.

    Returns:
        Whatever the parent `GenericFakeChatModel._generate` returns.
    """
    model_calls.append(len(messages))
    # Named parameters are not part of **kwargs; pass them through explicitly so
    # the parent sees the same arguments the caller supplied.
    return super()._generate(messages, stop=stop, run_manager=run_manager, **kwargs)
@@ -789,7 +1009,10 @@ class TestWrapModelCallDecorator:
"""Test basic decorator usage without parameters."""
@wrap_model_call
def passthrough_middleware(request, handler):
def passthrough_middleware(
request: ModelRequest,
handler: Callable[[ModelRequest], ModelResponse],
) -> ModelCallResult:
return handler(request)
# Should return an AgentMiddleware instance
@@ -807,7 +1030,10 @@ class TestWrapModelCallDecorator:
"""Test decorator with custom middleware name."""
@wrap_model_call(name="CustomMiddleware")
def my_middleware(request, handler):
def my_middleware(
request: ModelRequest,
handler: Callable[[ModelRequest], ModelResponse],
) -> ModelCallResult:
return handler(request)
assert isinstance(my_middleware, AgentMiddleware)
@@ -818,7 +1044,14 @@ class TestWrapModelCallDecorator:
call_count = {"value": 0}
class FailOnceThenSucceed(GenericFakeChatModel):
def _generate(self, messages, **kwargs):
@override
def _generate(
self,
messages: list[BaseMessage],
stop: list[str] | None = None,
run_manager: CallbackManagerForLLMRun | None = None,
**kwargs: Any,
) -> ChatResult:
call_count["value"] += 1
if call_count["value"] == 1:
msg = "First call fails"
@@ -826,7 +1059,10 @@ class TestWrapModelCallDecorator:
return super()._generate(messages, **kwargs)
@wrap_model_call
def retry_once(request, handler):
def retry_once(
request: ModelRequest,
handler: Callable[[ModelRequest], ModelResponse],
) -> ModelCallResult:
try:
return handler(request)
except Exception:
@@ -845,10 +1081,14 @@ class TestWrapModelCallDecorator:
"""Test decorator for rewriting responses."""
@wrap_model_call
def uppercase_responses(request, handler):
def uppercase_responses(
request: ModelRequest,
handler: Callable[[ModelRequest], ModelResponse],
) -> ModelCallResult:
result = handler(request)
# result is ModelResponse, extract AIMessage from it
ai_message = result.result[0]
assert isinstance(ai_message.content, str)
return AIMessage(content=ai_message.content.upper())
model = GenericFakeChatModel(messages=iter([AIMessage(content="hello world")]))
@@ -862,12 +1102,22 @@ class TestWrapModelCallDecorator:
"""Test decorator for error recovery."""
class AlwaysFailModel(GenericFakeChatModel):
def _generate(self, messages, **kwargs):
@override
def _generate(
self,
messages: list[BaseMessage],
stop: list[str] | None = None,
run_manager: CallbackManagerForLLMRun | None = None,
**kwargs: Any,
) -> ChatResult:
msg = "Model error"
raise ValueError(msg)
@wrap_model_call
def error_to_fallback(request, handler):
def error_to_fallback(
request: ModelRequest,
handler: Callable[[ModelRequest], ModelResponse],
) -> ModelCallResult:
try:
return handler(request)
except Exception:
@@ -885,7 +1135,10 @@ class TestWrapModelCallDecorator:
state_values = []
@wrap_model_call
def log_state(request, handler):
def log_state(
request: ModelRequest,
handler: Callable[[ModelRequest], ModelResponse],
) -> ModelCallResult:
state_values.append(request.state.get("messages"))
return handler(request)
@@ -904,14 +1157,20 @@ class TestWrapModelCallDecorator:
execution_order = []
@wrap_model_call
def outer_middleware(request, handler):
def outer_middleware(
request: ModelRequest,
handler: Callable[[ModelRequest], ModelResponse],
) -> ModelCallResult:
execution_order.append("outer-before")
result = handler(request)
execution_order.append("outer-after")
return result
@wrap_model_call
def inner_middleware(request, handler):
def inner_middleware(
request: ModelRequest,
handler: Callable[[ModelRequest], ModelResponse],
) -> ModelCallResult:
execution_order.append("inner-before")
result = handler(request)
execution_order.append("inner-after")
@@ -933,11 +1192,14 @@ class TestWrapModelCallDecorator:
"""Test decorator with custom state schema."""
class CustomState(TypedDict):
messages: list
messages: list[Any]
custom_field: str
@wrap_model_call(state_schema=CustomState)
def middleware_with_schema(request, handler):
def middleware_with_schema(
request: ModelRequest,
handler: Callable[[ModelRequest], ModelResponse],
) -> ModelCallResult:
return handler(request)
assert isinstance(middleware_with_schema, AgentMiddleware)
@@ -953,7 +1215,10 @@ class TestWrapModelCallDecorator:
return f"Result: {query}"
@wrap_model_call(tools=[test_tool])
def middleware_with_tools(request, handler):
def middleware_with_tools(
request: ModelRequest,
handler: Callable[[ModelRequest], ModelResponse],
) -> ModelCallResult:
return handler(request)
assert isinstance(middleware_with_tools, AgentMiddleware)
@@ -965,12 +1230,18 @@ class TestWrapModelCallDecorator:
# Without parentheses
@wrap_model_call
def middleware_no_parens(request, handler):
def middleware_no_parens(
request: ModelRequest,
handler: Callable[[ModelRequest], ModelResponse],
) -> ModelCallResult:
return handler(request)
# With parentheses
@wrap_model_call()
def middleware_with_parens(request, handler):
def middleware_with_parens(
request: ModelRequest,
handler: Callable[[ModelRequest], ModelResponse],
) -> ModelCallResult:
return handler(request)
assert isinstance(middleware_no_parens, AgentMiddleware)
@@ -980,7 +1251,10 @@ class TestWrapModelCallDecorator:
"""Test that decorator uses function name for class name."""
@wrap_model_call
def my_custom_middleware(request, handler):
def my_custom_middleware(
request: ModelRequest,
handler: Callable[[ModelRequest], ModelResponse],
) -> ModelCallResult:
return handler(request)
assert my_custom_middleware.__class__.__name__ == "my_custom_middleware"
@@ -990,14 +1264,21 @@ class TestWrapModelCallDecorator:
execution_order = []
@wrap_model_call
def decorated_middleware(request, handler):
def decorated_middleware(
request: ModelRequest,
handler: Callable[[ModelRequest], ModelResponse],
) -> ModelCallResult:
execution_order.append("decorated-before")
result = handler(request)
execution_order.append("decorated-after")
return result
class ClassMiddleware(AgentMiddleware):
def wrap_model_call(self, request, handler):
def wrap_model_call(
self,
request: ModelRequest,
handler: Callable[[ModelRequest], ModelResponse],
) -> ModelCallResult:
execution_order.append("class-before")
result = handler(request)
execution_order.append("class-after")
@@ -1025,7 +1306,14 @@ class TestWrapModelCallDecorator:
call_count = {"value": 0}
class UnreliableModel(GenericFakeChatModel):
def _generate(self, messages, **kwargs):
@override
def _generate(
self,
messages: list[BaseMessage],
stop: list[str] | None = None,
run_manager: CallbackManagerForLLMRun | None = None,
**kwargs: Any,
) -> ChatResult:
call_count["value"] += 1
if call_count["value"] <= 2:
msg = f"Attempt {call_count['value']} failed"
@@ -1033,7 +1321,10 @@ class TestWrapModelCallDecorator:
return super()._generate(messages, **kwargs)
@wrap_model_call
def retry_with_tracking(request, handler):
def retry_with_tracking(
request: ModelRequest,
handler: Callable[[ModelRequest], ModelResponse],
) -> ModelCallResult:
max_retries = 3
for attempt in range(max_retries):
attempts.append(attempt + 1)
@@ -1059,7 +1350,10 @@ class TestWrapModelCallDecorator:
modified_prompts = []
@wrap_model_call
def add_system_prompt(request, handler):
def add_system_prompt(
request: ModelRequest,
handler: Callable[[ModelRequest], ModelResponse],
) -> ModelCallResult:
# Modify request to add system prompt
modified_request = ModelRequest(
messages=request.messages,
@@ -1068,7 +1362,7 @@ class TestWrapModelCallDecorator:
tool_choice=request.tool_choice,
tools=request.tools,
response_format=request.response_format,
state={},
state=AgentState[Any](messages=[]),
runtime=None,
)
modified_prompts.append(modified_request.system_prompt)
@@ -1090,7 +1384,11 @@ class TestAsyncWrapModelCall:
log = []
class LoggingMiddleware(AgentMiddleware):
async def awrap_model_call(self, request, handler):
async def awrap_model_call(
self,
request: ModelRequest,
handler: Callable[[ModelRequest], Awaitable[ModelResponse]],
) -> ModelCallResult:
log.append("before")
result = await handler(request)
log.append("after")
@@ -1109,7 +1407,14 @@ class TestAsyncWrapModelCall:
call_count = {"value": 0}
class AsyncFailOnceThenSucceed(GenericFakeChatModel):
async def _agenerate(self, messages, **kwargs):
@override
async def _agenerate(
self,
messages: list[BaseMessage],
stop: list[str] | None = None,
run_manager: AsyncCallbackManagerForLLMRun | None = None,
**kwargs: Any,
) -> ChatResult:
call_count["value"] += 1
if call_count["value"] == 1:
msg = "First async call fails"
@@ -1117,7 +1422,11 @@ class TestAsyncWrapModelCall:
return await super()._agenerate(messages, **kwargs)
class RetryMiddleware(AgentMiddleware):
async def awrap_model_call(self, request, handler):
async def awrap_model_call(
self,
request: ModelRequest,
handler: Callable[[ModelRequest], Awaitable[ModelResponse]],
) -> ModelCallResult:
try:
return await handler(request)
except Exception:
@@ -1136,7 +1445,10 @@ class TestAsyncWrapModelCall:
call_log = []
@wrap_model_call
async def logging_middleware(request, handler):
async def logging_middleware(
request: ModelRequest,
handler: Callable[[ModelRequest], Awaitable[ModelResponse]],
) -> ModelCallResult:
call_log.append("before")
result = await handler(request)
call_log.append("after")
@@ -1161,8 +1473,8 @@ class TestSyncAsyncInterop:
async def awrap_model_call(
self,
request: ModelRequest,
handler: Callable[[ModelRequest], Awaitable[AIMessage]],
) -> AIMessage:
handler: Callable[[ModelRequest], Awaitable[ModelResponse]],
) -> ModelCallResult:
return await handler(request)
agent = create_agent(
@@ -1180,25 +1492,27 @@ class TestSyncAsyncInterop:
calls = []
class MixedMiddleware(AgentMiddleware):
def before_model(self, state, runtime) -> None:
@override
def before_model(self, state: AgentState[Any], runtime: Runtime[Any]) -> None:
calls.append("MixedMiddleware.before_model")
async def abefore_model(self, state, runtime) -> None:
@override
async def abefore_model(self, state: AgentState[Any], runtime: Runtime[Any]) -> None:
calls.append("MixedMiddleware.abefore_model")
def wrap_model_call(
self,
request: ModelRequest,
handler: Callable[[ModelRequest], AIMessage],
) -> AIMessage:
handler: Callable[[ModelRequest], ModelResponse],
) -> ModelCallResult:
calls.append("MixedMiddleware.wrap_model_call")
return handler(request)
async def awrap_model_call(
self,
request: ModelRequest,
handler: Callable[[ModelRequest], Awaitable[AIMessage]],
) -> AIMessage:
handler: Callable[[ModelRequest], Awaitable[ModelResponse]],
) -> ModelCallResult:
calls.append("MixedMiddleware.awrap_model_call")
return await handler(request)
@@ -1226,7 +1540,11 @@ class TestEdgeCases:
modified_messages = []
class RequestModifyingMiddleware(AgentMiddleware):
def wrap_model_call(self, request, handler):
def wrap_model_call(
self,
request: ModelRequest,
handler: Callable[[ModelRequest], ModelResponse],
) -> ModelCallResult:
# Add a system message to the request
modified_request = request
modified_messages.append(len(modified_request.messages))
@@ -1244,7 +1562,11 @@ class TestEdgeCases:
attempts = []
class MultiModelRetryMiddleware(AgentMiddleware):
def wrap_model_call(self, request, handler):
def wrap_model_call(
self,
request: ModelRequest,
handler: Callable[[ModelRequest], ModelResponse],
) -> ModelCallResult:
attempts.append("first-attempt")
try:
return handler(request)
@@ -1255,7 +1577,14 @@ class TestEdgeCases:
call_count = {"value": 0}
class FailFirstSucceedSecond(GenericFakeChatModel):
def _generate(self, messages, **kwargs):
@override
def _generate(
self,
messages: list[BaseMessage],
stop: list[str] | None = None,
run_manager: CallbackManagerForLLMRun | None = None,
**kwargs: Any,
) -> ChatResult:
call_count["value"] += 1
if call_count["value"] == 1:
msg = "First fails"