mirror of
https://github.com/hwchase17/langchain.git
synced 2026-03-18 02:53:16 +00:00
chore(langchain): fix types in test_wrap_model_call (#34573)
This commit is contained in:
committed by
GitHub
parent
0c7b7e045d
commit
f10225184d
@@ -7,18 +7,27 @@ This module tests the wrap_model_call functionality in three forms:
|
||||
"""
|
||||
|
||||
from collections.abc import Awaitable, Callable
|
||||
from typing import Any
|
||||
|
||||
import pytest
|
||||
from langchain_core.callbacks import (
|
||||
AsyncCallbackManagerForLLMRun,
|
||||
CallbackManagerForLLMRun,
|
||||
)
|
||||
from langchain_core.language_models.chat_models import BaseChatModel
|
||||
from langchain_core.language_models.fake_chat_models import GenericFakeChatModel
|
||||
from langchain_core.messages import AIMessage, HumanMessage
|
||||
from langchain_core.messages import AIMessage, BaseMessage, HumanMessage
|
||||
from langchain_core.outputs import ChatResult
|
||||
from langchain_core.tools import tool
|
||||
from typing_extensions import TypedDict
|
||||
from langgraph.runtime import Runtime
|
||||
from typing_extensions import TypedDict, override
|
||||
|
||||
from langchain.agents import create_agent
|
||||
from langchain.agents import AgentState, create_agent
|
||||
from langchain.agents.middleware.types import (
|
||||
AgentMiddleware,
|
||||
ModelCallResult,
|
||||
ModelRequest,
|
||||
ModelResponse,
|
||||
wrap_model_call,
|
||||
)
|
||||
from tests.unit_tests.agents.model import FakeToolCallingModel
|
||||
@@ -31,7 +40,11 @@ class TestBasicWrapModelCall:
|
||||
"""Test middleware that simply passes through without modification."""
|
||||
|
||||
class PassthroughMiddleware(AgentMiddleware):
|
||||
def wrap_model_call(self, request, handler):
|
||||
def wrap_model_call(
|
||||
self,
|
||||
request: ModelRequest,
|
||||
handler: Callable[[ModelRequest], ModelResponse],
|
||||
) -> ModelCallResult:
|
||||
return handler(request)
|
||||
|
||||
model = GenericFakeChatModel(messages=iter([AIMessage(content="Hello")]))
|
||||
@@ -47,7 +60,11 @@ class TestBasicWrapModelCall:
|
||||
call_log = []
|
||||
|
||||
class LoggingMiddleware(AgentMiddleware):
|
||||
def wrap_model_call(self, request, handler):
|
||||
def wrap_model_call(
|
||||
self,
|
||||
request: ModelRequest,
|
||||
handler: Callable[[ModelRequest], ModelResponse],
|
||||
) -> ModelCallResult:
|
||||
call_log.append("before")
|
||||
result = handler(request)
|
||||
call_log.append("after")
|
||||
@@ -65,11 +82,15 @@ class TestBasicWrapModelCall:
|
||||
"""Test middleware that counts model calls."""
|
||||
|
||||
class CountingMiddleware(AgentMiddleware):
|
||||
def __init__(self):
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
self.call_count = 0
|
||||
|
||||
def wrap_model_call(self, request, handler):
|
||||
def wrap_model_call(
|
||||
self,
|
||||
request: ModelRequest,
|
||||
handler: Callable[[ModelRequest], ModelResponse],
|
||||
) -> ModelCallResult:
|
||||
self.call_count += 1
|
||||
return handler(request)
|
||||
|
||||
@@ -90,7 +111,14 @@ class TestRetryLogic:
|
||||
call_count = {"value": 0}
|
||||
|
||||
class FailOnceThenSucceed(GenericFakeChatModel):
|
||||
def _generate(self, messages, **kwargs):
|
||||
@override
|
||||
def _generate(
|
||||
self,
|
||||
messages: list[BaseMessage],
|
||||
stop: list[str] | None = None,
|
||||
run_manager: CallbackManagerForLLMRun | None = None,
|
||||
**kwargs: Any,
|
||||
) -> ChatResult:
|
||||
call_count["value"] += 1
|
||||
if call_count["value"] == 1:
|
||||
msg = "First call fails"
|
||||
@@ -98,11 +126,15 @@ class TestRetryLogic:
|
||||
return super()._generate(messages, **kwargs)
|
||||
|
||||
class RetryOnceMiddleware(AgentMiddleware):
|
||||
def __init__(self):
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
self.retry_count = 0
|
||||
|
||||
def wrap_model_call(self, request, handler):
|
||||
def wrap_model_call(
|
||||
self,
|
||||
request: ModelRequest,
|
||||
handler: Callable[[ModelRequest], ModelResponse],
|
||||
) -> ModelCallResult:
|
||||
try:
|
||||
return handler(request)
|
||||
except Exception:
|
||||
@@ -122,17 +154,28 @@ class TestRetryLogic:
|
||||
"""Test middleware with maximum retry limit."""
|
||||
|
||||
class AlwaysFailModel(GenericFakeChatModel):
|
||||
def _generate(self, messages, **kwargs):
|
||||
@override
|
||||
def _generate(
|
||||
self,
|
||||
messages: list[BaseMessage],
|
||||
stop: list[str] | None = None,
|
||||
run_manager: CallbackManagerForLLMRun | None = None,
|
||||
**kwargs: Any,
|
||||
) -> ChatResult:
|
||||
msg = "Always fails"
|
||||
raise ValueError(msg)
|
||||
|
||||
class MaxRetriesMiddleware(AgentMiddleware):
|
||||
def __init__(self, max_retries=3):
|
||||
def __init__(self, max_retries: int = 3):
|
||||
super().__init__()
|
||||
self.max_retries = max_retries
|
||||
self.attempts = []
|
||||
|
||||
def wrap_model_call(self, request, handler):
|
||||
def wrap_model_call(
|
||||
self,
|
||||
request: ModelRequest,
|
||||
handler: Callable[[ModelRequest], ModelResponse],
|
||||
) -> ModelCallResult:
|
||||
last_exception = None
|
||||
for attempt in range(self.max_retries):
|
||||
self.attempts.append(attempt + 1)
|
||||
@@ -161,16 +204,27 @@ class TestRetryLogic:
|
||||
class FailingModel(BaseChatModel):
|
||||
"""Model that always fails."""
|
||||
|
||||
def _generate(self, messages, **kwargs):
|
||||
@override
|
||||
def _generate(
|
||||
self,
|
||||
messages: list[BaseMessage],
|
||||
stop: list[str] | None = None,
|
||||
run_manager: CallbackManagerForLLMRun | None = None,
|
||||
**kwargs: Any,
|
||||
) -> ChatResult:
|
||||
msg = "Model error"
|
||||
raise ValueError(msg)
|
||||
|
||||
@property
|
||||
def _llm_type(self):
|
||||
def _llm_type(self) -> str:
|
||||
return "failing"
|
||||
|
||||
class NoRetryMiddleware(AgentMiddleware):
|
||||
def wrap_model_call(self, request, handler):
|
||||
def wrap_model_call(
|
||||
self,
|
||||
request: ModelRequest,
|
||||
handler: Callable[[ModelRequest], ModelResponse],
|
||||
) -> ModelCallResult:
|
||||
return handler(request)
|
||||
|
||||
agent = create_agent(model=FailingModel(), middleware=[NoRetryMiddleware()])
|
||||
@@ -184,12 +238,19 @@ class TestRetryLogic:
|
||||
class AlwaysFailingModel(BaseChatModel):
|
||||
"""Model that always fails."""
|
||||
|
||||
def _generate(self, messages, **kwargs):
|
||||
@override
|
||||
def _generate(
|
||||
self,
|
||||
messages: list[BaseMessage],
|
||||
stop: list[str] | None = None,
|
||||
run_manager: CallbackManagerForLLMRun | None = None,
|
||||
**kwargs: Any,
|
||||
) -> ChatResult:
|
||||
msg = "Always fails"
|
||||
raise ValueError(msg)
|
||||
|
||||
@property
|
||||
def _llm_type(self):
|
||||
def _llm_type(self) -> str:
|
||||
return "always_failing"
|
||||
|
||||
class LimitedRetryMiddleware(AgentMiddleware):
|
||||
@@ -200,7 +261,11 @@ class TestRetryLogic:
|
||||
self.max_retries = max_retries
|
||||
self.attempt_count = 0
|
||||
|
||||
def wrap_model_call(self, request, handler):
|
||||
def wrap_model_call(
|
||||
self,
|
||||
request: ModelRequest,
|
||||
handler: Callable[[ModelRequest], ModelResponse],
|
||||
) -> ModelCallResult:
|
||||
last_exception = None
|
||||
for _attempt in range(self.max_retries):
|
||||
self.attempt_count += 1
|
||||
@@ -235,10 +300,15 @@ class TestResponseRewriting:
|
||||
"""Test middleware that transforms response to uppercase."""
|
||||
|
||||
class UppercaseMiddleware(AgentMiddleware):
|
||||
def wrap_model_call(self, request, handler):
|
||||
def wrap_model_call(
|
||||
self,
|
||||
request: ModelRequest,
|
||||
handler: Callable[[ModelRequest], ModelResponse],
|
||||
) -> ModelCallResult:
|
||||
result = handler(request)
|
||||
# result is ModelResponse, extract AIMessage from it
|
||||
ai_message = result.result[0]
|
||||
assert isinstance(ai_message.content, str)
|
||||
return AIMessage(content=ai_message.content.upper())
|
||||
|
||||
model = GenericFakeChatModel(messages=iter([AIMessage(content="hello world")]))
|
||||
@@ -256,7 +326,11 @@ class TestResponseRewriting:
|
||||
super().__init__()
|
||||
self.prefix = prefix
|
||||
|
||||
def wrap_model_call(self, request, handler):
|
||||
def wrap_model_call(
|
||||
self,
|
||||
request: ModelRequest,
|
||||
handler: Callable[[ModelRequest], ModelResponse],
|
||||
) -> ModelCallResult:
|
||||
result = handler(request)
|
||||
# result is ModelResponse, extract AIMessage from it
|
||||
ai_message = result.result[0]
|
||||
@@ -273,12 +347,17 @@ class TestResponseRewriting:
|
||||
"""Test middleware applying multiple transformations."""
|
||||
|
||||
class MultiTransformMiddleware(AgentMiddleware):
|
||||
def wrap_model_call(self, request, handler):
|
||||
def wrap_model_call(
|
||||
self,
|
||||
request: ModelRequest,
|
||||
handler: Callable[[ModelRequest], ModelResponse],
|
||||
) -> ModelCallResult:
|
||||
result = handler(request)
|
||||
# result is ModelResponse, extract AIMessage from it
|
||||
ai_message = result.result[0]
|
||||
|
||||
# First transformation: uppercase
|
||||
assert isinstance(ai_message.content, str)
|
||||
content = ai_message.content.upper()
|
||||
# Second transformation: add prefix and suffix
|
||||
content = f"[START] {content} [END]"
|
||||
@@ -299,12 +378,23 @@ class TestErrorHandling:
|
||||
"""Test middleware that converts errors to successful responses."""
|
||||
|
||||
class AlwaysFailModel(GenericFakeChatModel):
|
||||
def _generate(self, messages, **kwargs):
|
||||
@override
|
||||
def _generate(
|
||||
self,
|
||||
messages: list[BaseMessage],
|
||||
stop: list[str] | None = None,
|
||||
run_manager: CallbackManagerForLLMRun | None = None,
|
||||
**kwargs: Any,
|
||||
) -> ChatResult:
|
||||
msg = "Model error"
|
||||
raise ValueError(msg)
|
||||
|
||||
class ErrorToSuccessMiddleware(AgentMiddleware):
|
||||
def wrap_model_call(self, request, handler):
|
||||
def wrap_model_call(
|
||||
self,
|
||||
request: ModelRequest,
|
||||
handler: Callable[[ModelRequest], ModelResponse],
|
||||
) -> ModelCallResult:
|
||||
try:
|
||||
return handler(request)
|
||||
except Exception as e:
|
||||
@@ -323,12 +413,23 @@ class TestErrorHandling:
|
||||
"""Test middleware that only handles specific errors."""
|
||||
|
||||
class SpecificErrorModel(GenericFakeChatModel):
|
||||
def _generate(self, messages, **kwargs):
|
||||
@override
|
||||
def _generate(
|
||||
self,
|
||||
messages: list[BaseMessage],
|
||||
stop: list[str] | None = None,
|
||||
run_manager: CallbackManagerForLLMRun | None = None,
|
||||
**kwargs: Any,
|
||||
) -> ChatResult:
|
||||
msg = "Network error"
|
||||
raise ConnectionError(msg)
|
||||
|
||||
class SelectiveErrorMiddleware(AgentMiddleware):
|
||||
def wrap_model_call(self, request, handler):
|
||||
def wrap_model_call(
|
||||
self,
|
||||
request: ModelRequest,
|
||||
handler: Callable[[ModelRequest], ModelResponse],
|
||||
) -> ModelCallResult:
|
||||
try:
|
||||
return handler(request)
|
||||
except ConnectionError:
|
||||
@@ -346,7 +447,11 @@ class TestErrorHandling:
|
||||
call_log = []
|
||||
|
||||
class ErrorRecoveryMiddleware(AgentMiddleware):
|
||||
def wrap_model_call(self, request, handler):
|
||||
def wrap_model_call(
|
||||
self,
|
||||
request: ModelRequest,
|
||||
handler: Callable[[ModelRequest], ModelResponse],
|
||||
) -> ModelCallResult:
|
||||
try:
|
||||
call_log.append("before-yield")
|
||||
result = handler(request)
|
||||
@@ -369,7 +474,14 @@ class TestErrorHandling:
|
||||
call_log.clear()
|
||||
|
||||
class AlwaysFailModel(GenericFakeChatModel):
|
||||
def _generate(self, messages, **kwargs):
|
||||
@override
|
||||
def _generate(
|
||||
self,
|
||||
messages: list[BaseMessage],
|
||||
stop: list[str] | None = None,
|
||||
run_manager: CallbackManagerForLLMRun | None = None,
|
||||
**kwargs: Any,
|
||||
) -> ChatResult:
|
||||
msg = "Model error"
|
||||
raise ValueError(msg)
|
||||
|
||||
@@ -386,11 +498,15 @@ class TestShortCircuit:
|
||||
|
||||
def test_cache_short_circuit(self) -> None:
|
||||
"""Test middleware that short-circuits with cached response."""
|
||||
cache = {}
|
||||
cache: dict[str, ModelResponse] = {}
|
||||
model_calls = []
|
||||
|
||||
class CachingMiddleware(AgentMiddleware):
|
||||
def wrap_model_call(self, request, handler):
|
||||
def wrap_model_call(
|
||||
self,
|
||||
request: ModelRequest,
|
||||
handler: Callable[[ModelRequest], ModelResponse],
|
||||
) -> ModelCallResult:
|
||||
# Simple cache key based on last message
|
||||
cache_key = str(request.messages[-1].content) if request.messages else ""
|
||||
|
||||
@@ -403,7 +519,14 @@ class TestShortCircuit:
|
||||
return result
|
||||
|
||||
class TrackingModel(GenericFakeChatModel):
|
||||
def _generate(self, messages, **kwargs):
|
||||
@override
|
||||
def _generate(
|
||||
self,
|
||||
messages: list[BaseMessage],
|
||||
stop: list[str] | None = None,
|
||||
run_manager: CallbackManagerForLLMRun | None = None,
|
||||
**kwargs: Any,
|
||||
) -> ChatResult:
|
||||
model_calls.append(len(messages))
|
||||
return super()._generate(messages, **kwargs)
|
||||
|
||||
@@ -445,7 +568,11 @@ class TestRequestModification:
|
||||
super().__init__()
|
||||
self.system_prompt = system_prompt
|
||||
|
||||
def wrap_model_call(self, request, handler):
|
||||
def wrap_model_call(
|
||||
self,
|
||||
request: ModelRequest,
|
||||
handler: Callable[[ModelRequest], ModelResponse],
|
||||
) -> ModelCallResult:
|
||||
# Modify request to add system prompt
|
||||
modified_request = ModelRequest(
|
||||
model=request.model,
|
||||
@@ -482,7 +609,11 @@ class TestStateAndRuntime:
|
||||
state_values = []
|
||||
|
||||
class StateAwareMiddleware(AgentMiddleware):
|
||||
def wrap_model_call(self, request, handler):
|
||||
def wrap_model_call(
|
||||
self,
|
||||
request: ModelRequest,
|
||||
handler: Callable[[ModelRequest], ModelResponse],
|
||||
) -> ModelCallResult:
|
||||
# Access state from request
|
||||
state_values.append(
|
||||
{
|
||||
@@ -504,7 +635,11 @@ class TestStateAndRuntime:
|
||||
"""Test middleware that tracks retry count in state."""
|
||||
|
||||
class StateTrackingRetryMiddleware(AgentMiddleware):
|
||||
def wrap_model_call(self, request, handler):
|
||||
def wrap_model_call(
|
||||
self,
|
||||
request: ModelRequest,
|
||||
handler: Callable[[ModelRequest], ModelResponse],
|
||||
) -> ModelCallResult:
|
||||
max_retries = 2
|
||||
for attempt in range(max_retries):
|
||||
try:
|
||||
@@ -517,7 +652,14 @@ class TestStateAndRuntime:
|
||||
call_count = {"value": 0}
|
||||
|
||||
class FailOnceThenSucceed(GenericFakeChatModel):
|
||||
def _generate(self, messages, **kwargs):
|
||||
@override
|
||||
def _generate(
|
||||
self,
|
||||
messages: list[BaseMessage],
|
||||
stop: list[str] | None = None,
|
||||
run_manager: CallbackManagerForLLMRun | None = None,
|
||||
**kwargs: Any,
|
||||
) -> ChatResult:
|
||||
call_count["value"] += 1
|
||||
if call_count["value"] == 1:
|
||||
msg = "First fails"
|
||||
@@ -541,14 +683,22 @@ class TestMiddlewareComposition:
|
||||
execution_order = []
|
||||
|
||||
class OuterMiddleware(AgentMiddleware):
|
||||
def wrap_model_call(self, request, handler):
|
||||
def wrap_model_call(
|
||||
self,
|
||||
request: ModelRequest,
|
||||
handler: Callable[[ModelRequest], ModelResponse],
|
||||
) -> ModelCallResult:
|
||||
execution_order.append("outer-before")
|
||||
response = handler(request)
|
||||
execution_order.append("outer-after")
|
||||
return response
|
||||
|
||||
class InnerMiddleware(AgentMiddleware):
|
||||
def wrap_model_call(self, request, handler):
|
||||
def wrap_model_call(
|
||||
self,
|
||||
request: ModelRequest,
|
||||
handler: Callable[[ModelRequest], ModelResponse],
|
||||
) -> ModelCallResult:
|
||||
execution_order.append("inner-before")
|
||||
response = handler(request)
|
||||
execution_order.append("inner-after")
|
||||
@@ -572,21 +722,33 @@ class TestMiddlewareComposition:
|
||||
execution_order = []
|
||||
|
||||
class FirstMiddleware(AgentMiddleware):
|
||||
def wrap_model_call(self, request, handler):
|
||||
def wrap_model_call(
|
||||
self,
|
||||
request: ModelRequest,
|
||||
handler: Callable[[ModelRequest], ModelResponse],
|
||||
) -> ModelCallResult:
|
||||
execution_order.append("first-before")
|
||||
response = handler(request)
|
||||
execution_order.append("first-after")
|
||||
return response
|
||||
|
||||
class SecondMiddleware(AgentMiddleware):
|
||||
def wrap_model_call(self, request, handler):
|
||||
def wrap_model_call(
|
||||
self,
|
||||
request: ModelRequest,
|
||||
handler: Callable[[ModelRequest], ModelResponse],
|
||||
) -> ModelCallResult:
|
||||
execution_order.append("second-before")
|
||||
response = handler(request)
|
||||
execution_order.append("second-after")
|
||||
return response
|
||||
|
||||
class ThirdMiddleware(AgentMiddleware):
|
||||
def wrap_model_call(self, request, handler):
|
||||
def wrap_model_call(
|
||||
self,
|
||||
request: ModelRequest,
|
||||
handler: Callable[[ModelRequest], ModelResponse],
|
||||
) -> ModelCallResult:
|
||||
execution_order.append("third-before")
|
||||
response = handler(request)
|
||||
execution_order.append("third-after")
|
||||
@@ -617,7 +779,14 @@ class TestMiddlewareComposition:
|
||||
log = []
|
||||
|
||||
class FailOnceThenSucceed(GenericFakeChatModel):
|
||||
def _generate(self, messages, **kwargs):
|
||||
@override
|
||||
def _generate(
|
||||
self,
|
||||
messages: list[BaseMessage],
|
||||
stop: list[str] | None = None,
|
||||
run_manager: CallbackManagerForLLMRun | None = None,
|
||||
**kwargs: Any,
|
||||
) -> ChatResult:
|
||||
call_count["value"] += 1
|
||||
if call_count["value"] == 1:
|
||||
msg = "First call fails"
|
||||
@@ -625,14 +794,22 @@ class TestMiddlewareComposition:
|
||||
return super()._generate(messages, **kwargs)
|
||||
|
||||
class LoggingMiddleware(AgentMiddleware):
|
||||
def wrap_model_call(self, request, handler):
|
||||
def wrap_model_call(
|
||||
self,
|
||||
request: ModelRequest,
|
||||
handler: Callable[[ModelRequest], ModelResponse],
|
||||
) -> ModelCallResult:
|
||||
log.append("logging-before")
|
||||
result = handler(request)
|
||||
log.append("logging-after")
|
||||
return result
|
||||
|
||||
class RetryMiddleware(AgentMiddleware):
|
||||
def wrap_model_call(self, request, handler):
|
||||
def wrap_model_call(
|
||||
self,
|
||||
request: ModelRequest,
|
||||
handler: Callable[[ModelRequest], ModelResponse],
|
||||
) -> ModelCallResult:
|
||||
log.append("retry-before")
|
||||
try:
|
||||
result = handler(request)
|
||||
@@ -664,14 +841,22 @@ class TestMiddlewareComposition:
|
||||
"""Test multiple middleware that each transform the response."""
|
||||
|
||||
class PrefixMiddleware(AgentMiddleware):
|
||||
def wrap_model_call(self, request, handler):
|
||||
def wrap_model_call(
|
||||
self,
|
||||
request: ModelRequest,
|
||||
handler: Callable[[ModelRequest], ModelResponse],
|
||||
) -> ModelCallResult:
|
||||
result = handler(request)
|
||||
# result is ModelResponse, extract AIMessage from it
|
||||
ai_message = result.result[0]
|
||||
return AIMessage(content=f"[PREFIX] {ai_message.content}")
|
||||
|
||||
class SuffixMiddleware(AgentMiddleware):
|
||||
def wrap_model_call(self, request, handler):
|
||||
def wrap_model_call(
|
||||
self,
|
||||
request: ModelRequest,
|
||||
handler: Callable[[ModelRequest], ModelResponse],
|
||||
) -> ModelCallResult:
|
||||
result = handler(request)
|
||||
# result is ModelResponse, extract AIMessage from it
|
||||
ai_message = result.result[0]
|
||||
@@ -692,7 +877,14 @@ class TestMiddlewareComposition:
|
||||
call_count = {"value": 0}
|
||||
|
||||
class FailOnceThenSucceed(GenericFakeChatModel):
|
||||
def _generate(self, messages, **kwargs):
|
||||
@override
|
||||
def _generate(
|
||||
self,
|
||||
messages: list[BaseMessage],
|
||||
stop: list[str] | None = None,
|
||||
run_manager: CallbackManagerForLLMRun | None = None,
|
||||
**kwargs: Any,
|
||||
) -> ChatResult:
|
||||
call_count["value"] += 1
|
||||
if call_count["value"] == 1:
|
||||
msg = "First call fails"
|
||||
@@ -700,17 +892,26 @@ class TestMiddlewareComposition:
|
||||
return super()._generate(messages, **kwargs)
|
||||
|
||||
class RetryMiddleware(AgentMiddleware):
|
||||
def wrap_model_call(self, request, handler):
|
||||
def wrap_model_call(
|
||||
self,
|
||||
request: ModelRequest,
|
||||
handler: Callable[[ModelRequest], ModelResponse],
|
||||
) -> ModelCallResult:
|
||||
try:
|
||||
return handler(request)
|
||||
except Exception:
|
||||
return handler(request)
|
||||
|
||||
class UppercaseMiddleware(AgentMiddleware):
|
||||
def wrap_model_call(self, request, handler):
|
||||
def wrap_model_call(
|
||||
self,
|
||||
request: ModelRequest,
|
||||
handler: Callable[[ModelRequest], ModelResponse],
|
||||
) -> ModelCallResult:
|
||||
result = handler(request)
|
||||
# result is ModelResponse, extract AIMessage from it
|
||||
ai_message = result.result[0]
|
||||
assert isinstance(ai_message.content, str)
|
||||
return AIMessage(content=ai_message.content.upper())
|
||||
|
||||
model = FailOnceThenSucceed(messages=iter([AIMessage(content="success")]))
|
||||
@@ -728,14 +929,22 @@ class TestMiddlewareComposition:
|
||||
model_calls = []
|
||||
|
||||
class OuterMiddleware(AgentMiddleware):
|
||||
def wrap_model_call(self, request, handler):
|
||||
def wrap_model_call(
|
||||
self,
|
||||
request: ModelRequest,
|
||||
handler: Callable[[ModelRequest], ModelResponse],
|
||||
) -> ModelCallResult:
|
||||
execution_order.append("outer-before")
|
||||
result = handler(request)
|
||||
execution_order.append("outer-after")
|
||||
return result
|
||||
|
||||
class MiddleRetryMiddleware(AgentMiddleware):
|
||||
def wrap_model_call(self, request, handler):
|
||||
def wrap_model_call(
|
||||
self,
|
||||
request: ModelRequest,
|
||||
handler: Callable[[ModelRequest], ModelResponse],
|
||||
) -> ModelCallResult:
|
||||
execution_order.append("middle-before")
|
||||
# Always retry once (call handler twice)
|
||||
result = handler(request)
|
||||
@@ -745,14 +954,25 @@ class TestMiddlewareComposition:
|
||||
return result
|
||||
|
||||
class InnerMiddleware(AgentMiddleware):
|
||||
def wrap_model_call(self, request, handler):
|
||||
def wrap_model_call(
|
||||
self,
|
||||
request: ModelRequest,
|
||||
handler: Callable[[ModelRequest], ModelResponse],
|
||||
) -> ModelCallResult:
|
||||
execution_order.append("inner-before")
|
||||
result = handler(request)
|
||||
execution_order.append("inner-after")
|
||||
return result
|
||||
|
||||
class TrackingModel(GenericFakeChatModel):
|
||||
def _generate(self, messages, **kwargs):
|
||||
@override
|
||||
def _generate(
|
||||
self,
|
||||
messages: list[BaseMessage],
|
||||
stop: list[str] | None = None,
|
||||
run_manager: CallbackManagerForLLMRun | None = None,
|
||||
**kwargs: Any,
|
||||
) -> ChatResult:
|
||||
model_calls.append(len(messages))
|
||||
return super()._generate(messages, **kwargs)
|
||||
|
||||
@@ -789,7 +1009,10 @@ class TestWrapModelCallDecorator:
|
||||
"""Test basic decorator usage without parameters."""
|
||||
|
||||
@wrap_model_call
|
||||
def passthrough_middleware(request, handler):
|
||||
def passthrough_middleware(
|
||||
request: ModelRequest,
|
||||
handler: Callable[[ModelRequest], ModelResponse],
|
||||
) -> ModelCallResult:
|
||||
return handler(request)
|
||||
|
||||
# Should return an AgentMiddleware instance
|
||||
@@ -807,7 +1030,10 @@ class TestWrapModelCallDecorator:
|
||||
"""Test decorator with custom middleware name."""
|
||||
|
||||
@wrap_model_call(name="CustomMiddleware")
|
||||
def my_middleware(request, handler):
|
||||
def my_middleware(
|
||||
request: ModelRequest,
|
||||
handler: Callable[[ModelRequest], ModelResponse],
|
||||
) -> ModelCallResult:
|
||||
return handler(request)
|
||||
|
||||
assert isinstance(my_middleware, AgentMiddleware)
|
||||
@@ -818,7 +1044,14 @@ class TestWrapModelCallDecorator:
|
||||
call_count = {"value": 0}
|
||||
|
||||
class FailOnceThenSucceed(GenericFakeChatModel):
|
||||
def _generate(self, messages, **kwargs):
|
||||
@override
|
||||
def _generate(
|
||||
self,
|
||||
messages: list[BaseMessage],
|
||||
stop: list[str] | None = None,
|
||||
run_manager: CallbackManagerForLLMRun | None = None,
|
||||
**kwargs: Any,
|
||||
) -> ChatResult:
|
||||
call_count["value"] += 1
|
||||
if call_count["value"] == 1:
|
||||
msg = "First call fails"
|
||||
@@ -826,7 +1059,10 @@ class TestWrapModelCallDecorator:
|
||||
return super()._generate(messages, **kwargs)
|
||||
|
||||
@wrap_model_call
|
||||
def retry_once(request, handler):
|
||||
def retry_once(
|
||||
request: ModelRequest,
|
||||
handler: Callable[[ModelRequest], ModelResponse],
|
||||
) -> ModelCallResult:
|
||||
try:
|
||||
return handler(request)
|
||||
except Exception:
|
||||
@@ -845,10 +1081,14 @@ class TestWrapModelCallDecorator:
|
||||
"""Test decorator for rewriting responses."""
|
||||
|
||||
@wrap_model_call
|
||||
def uppercase_responses(request, handler):
|
||||
def uppercase_responses(
|
||||
request: ModelRequest,
|
||||
handler: Callable[[ModelRequest], ModelResponse],
|
||||
) -> ModelCallResult:
|
||||
result = handler(request)
|
||||
# result is ModelResponse, extract AIMessage from it
|
||||
ai_message = result.result[0]
|
||||
assert isinstance(ai_message.content, str)
|
||||
return AIMessage(content=ai_message.content.upper())
|
||||
|
||||
model = GenericFakeChatModel(messages=iter([AIMessage(content="hello world")]))
|
||||
@@ -862,12 +1102,22 @@ class TestWrapModelCallDecorator:
|
||||
"""Test decorator for error recovery."""
|
||||
|
||||
class AlwaysFailModel(GenericFakeChatModel):
|
||||
def _generate(self, messages, **kwargs):
|
||||
@override
|
||||
def _generate(
|
||||
self,
|
||||
messages: list[BaseMessage],
|
||||
stop: list[str] | None = None,
|
||||
run_manager: CallbackManagerForLLMRun | None = None,
|
||||
**kwargs: Any,
|
||||
) -> ChatResult:
|
||||
msg = "Model error"
|
||||
raise ValueError(msg)
|
||||
|
||||
@wrap_model_call
|
||||
def error_to_fallback(request, handler):
|
||||
def error_to_fallback(
|
||||
request: ModelRequest,
|
||||
handler: Callable[[ModelRequest], ModelResponse],
|
||||
) -> ModelCallResult:
|
||||
try:
|
||||
return handler(request)
|
||||
except Exception:
|
||||
@@ -885,7 +1135,10 @@ class TestWrapModelCallDecorator:
|
||||
state_values = []
|
||||
|
||||
@wrap_model_call
|
||||
def log_state(request, handler):
|
||||
def log_state(
|
||||
request: ModelRequest,
|
||||
handler: Callable[[ModelRequest], ModelResponse],
|
||||
) -> ModelCallResult:
|
||||
state_values.append(request.state.get("messages"))
|
||||
return handler(request)
|
||||
|
||||
@@ -904,14 +1157,20 @@ class TestWrapModelCallDecorator:
|
||||
execution_order = []
|
||||
|
||||
@wrap_model_call
|
||||
def outer_middleware(request, handler):
|
||||
def outer_middleware(
|
||||
request: ModelRequest,
|
||||
handler: Callable[[ModelRequest], ModelResponse],
|
||||
) -> ModelCallResult:
|
||||
execution_order.append("outer-before")
|
||||
result = handler(request)
|
||||
execution_order.append("outer-after")
|
||||
return result
|
||||
|
||||
@wrap_model_call
|
||||
def inner_middleware(request, handler):
|
||||
def inner_middleware(
|
||||
request: ModelRequest,
|
||||
handler: Callable[[ModelRequest], ModelResponse],
|
||||
) -> ModelCallResult:
|
||||
execution_order.append("inner-before")
|
||||
result = handler(request)
|
||||
execution_order.append("inner-after")
|
||||
@@ -933,11 +1192,14 @@ class TestWrapModelCallDecorator:
|
||||
"""Test decorator with custom state schema."""
|
||||
|
||||
class CustomState(TypedDict):
|
||||
messages: list
|
||||
messages: list[Any]
|
||||
custom_field: str
|
||||
|
||||
@wrap_model_call(state_schema=CustomState)
|
||||
def middleware_with_schema(request, handler):
|
||||
def middleware_with_schema(
|
||||
request: ModelRequest,
|
||||
handler: Callable[[ModelRequest], ModelResponse],
|
||||
) -> ModelCallResult:
|
||||
return handler(request)
|
||||
|
||||
assert isinstance(middleware_with_schema, AgentMiddleware)
|
||||
@@ -953,7 +1215,10 @@ class TestWrapModelCallDecorator:
|
||||
return f"Result: {query}"
|
||||
|
||||
@wrap_model_call(tools=[test_tool])
|
||||
def middleware_with_tools(request, handler):
|
||||
def middleware_with_tools(
|
||||
request: ModelRequest,
|
||||
handler: Callable[[ModelRequest], ModelResponse],
|
||||
) -> ModelCallResult:
|
||||
return handler(request)
|
||||
|
||||
assert isinstance(middleware_with_tools, AgentMiddleware)
|
||||
@@ -965,12 +1230,18 @@ class TestWrapModelCallDecorator:
|
||||
|
||||
# Without parentheses
|
||||
@wrap_model_call
|
||||
def middleware_no_parens(request, handler):
|
||||
def middleware_no_parens(
|
||||
request: ModelRequest,
|
||||
handler: Callable[[ModelRequest], ModelResponse],
|
||||
) -> ModelCallResult:
|
||||
return handler(request)
|
||||
|
||||
# With parentheses
|
||||
@wrap_model_call()
|
||||
def middleware_with_parens(request, handler):
|
||||
def middleware_with_parens(
|
||||
request: ModelRequest,
|
||||
handler: Callable[[ModelRequest], ModelResponse],
|
||||
) -> ModelCallResult:
|
||||
return handler(request)
|
||||
|
||||
assert isinstance(middleware_no_parens, AgentMiddleware)
|
||||
@@ -980,7 +1251,10 @@ class TestWrapModelCallDecorator:
|
||||
"""Test that decorator uses function name for class name."""
|
||||
|
||||
@wrap_model_call
|
||||
def my_custom_middleware(request, handler):
|
||||
def my_custom_middleware(
|
||||
request: ModelRequest,
|
||||
handler: Callable[[ModelRequest], ModelResponse],
|
||||
) -> ModelCallResult:
|
||||
return handler(request)
|
||||
|
||||
assert my_custom_middleware.__class__.__name__ == "my_custom_middleware"
|
||||
@@ -990,14 +1264,21 @@ class TestWrapModelCallDecorator:
|
||||
execution_order = []
|
||||
|
||||
@wrap_model_call
|
||||
def decorated_middleware(request, handler):
|
||||
def decorated_middleware(
|
||||
request: ModelRequest,
|
||||
handler: Callable[[ModelRequest], ModelResponse],
|
||||
) -> ModelCallResult:
|
||||
execution_order.append("decorated-before")
|
||||
result = handler(request)
|
||||
execution_order.append("decorated-after")
|
||||
return result
|
||||
|
||||
class ClassMiddleware(AgentMiddleware):
|
||||
def wrap_model_call(self, request, handler):
|
||||
def wrap_model_call(
|
||||
self,
|
||||
request: ModelRequest,
|
||||
handler: Callable[[ModelRequest], ModelResponse],
|
||||
) -> ModelCallResult:
|
||||
execution_order.append("class-before")
|
||||
result = handler(request)
|
||||
execution_order.append("class-after")
|
||||
@@ -1025,7 +1306,14 @@ class TestWrapModelCallDecorator:
|
||||
call_count = {"value": 0}
|
||||
|
||||
class UnreliableModel(GenericFakeChatModel):
|
||||
def _generate(self, messages, **kwargs):
|
||||
@override
|
||||
def _generate(
|
||||
self,
|
||||
messages: list[BaseMessage],
|
||||
stop: list[str] | None = None,
|
||||
run_manager: CallbackManagerForLLMRun | None = None,
|
||||
**kwargs: Any,
|
||||
) -> ChatResult:
|
||||
call_count["value"] += 1
|
||||
if call_count["value"] <= 2:
|
||||
msg = f"Attempt {call_count['value']} failed"
|
||||
@@ -1033,7 +1321,10 @@ class TestWrapModelCallDecorator:
|
||||
return super()._generate(messages, **kwargs)
|
||||
|
||||
@wrap_model_call
|
||||
def retry_with_tracking(request, handler):
|
||||
def retry_with_tracking(
|
||||
request: ModelRequest,
|
||||
handler: Callable[[ModelRequest], ModelResponse],
|
||||
) -> ModelCallResult:
|
||||
max_retries = 3
|
||||
for attempt in range(max_retries):
|
||||
attempts.append(attempt + 1)
|
||||
@@ -1059,7 +1350,10 @@ class TestWrapModelCallDecorator:
|
||||
modified_prompts = []
|
||||
|
||||
@wrap_model_call
|
||||
def add_system_prompt(request, handler):
|
||||
def add_system_prompt(
|
||||
request: ModelRequest,
|
||||
handler: Callable[[ModelRequest], ModelResponse],
|
||||
) -> ModelCallResult:
|
||||
# Modify request to add system prompt
|
||||
modified_request = ModelRequest(
|
||||
messages=request.messages,
|
||||
@@ -1068,7 +1362,7 @@ class TestWrapModelCallDecorator:
|
||||
tool_choice=request.tool_choice,
|
||||
tools=request.tools,
|
||||
response_format=request.response_format,
|
||||
state={},
|
||||
state=AgentState[Any](messages=[]),
|
||||
runtime=None,
|
||||
)
|
||||
modified_prompts.append(modified_request.system_prompt)
|
||||
@@ -1090,7 +1384,11 @@ class TestAsyncWrapModelCall:
|
||||
log = []
|
||||
|
||||
class LoggingMiddleware(AgentMiddleware):
|
||||
async def awrap_model_call(self, request, handler):
|
||||
async def awrap_model_call(
|
||||
self,
|
||||
request: ModelRequest,
|
||||
handler: Callable[[ModelRequest], Awaitable[ModelResponse]],
|
||||
) -> ModelCallResult:
|
||||
log.append("before")
|
||||
result = await handler(request)
|
||||
log.append("after")
|
||||
@@ -1109,7 +1407,14 @@ class TestAsyncWrapModelCall:
|
||||
call_count = {"value": 0}
|
||||
|
||||
class AsyncFailOnceThenSucceed(GenericFakeChatModel):
|
||||
async def _agenerate(self, messages, **kwargs):
|
||||
@override
|
||||
async def _agenerate(
|
||||
self,
|
||||
messages: list[BaseMessage],
|
||||
stop: list[str] | None = None,
|
||||
run_manager: AsyncCallbackManagerForLLMRun | None = None,
|
||||
**kwargs: Any,
|
||||
) -> ChatResult:
|
||||
call_count["value"] += 1
|
||||
if call_count["value"] == 1:
|
||||
msg = "First async call fails"
|
||||
@@ -1117,7 +1422,11 @@ class TestAsyncWrapModelCall:
|
||||
return await super()._agenerate(messages, **kwargs)
|
||||
|
||||
class RetryMiddleware(AgentMiddleware):
|
||||
async def awrap_model_call(self, request, handler):
|
||||
async def awrap_model_call(
|
||||
self,
|
||||
request: ModelRequest,
|
||||
handler: Callable[[ModelRequest], Awaitable[ModelResponse]],
|
||||
) -> ModelCallResult:
|
||||
try:
|
||||
return await handler(request)
|
||||
except Exception:
|
||||
@@ -1136,7 +1445,10 @@ class TestAsyncWrapModelCall:
|
||||
call_log = []
|
||||
|
||||
@wrap_model_call
|
||||
async def logging_middleware(request, handler):
|
||||
async def logging_middleware(
|
||||
request: ModelRequest,
|
||||
handler: Callable[[ModelRequest], Awaitable[ModelResponse]],
|
||||
) -> ModelCallResult:
|
||||
call_log.append("before")
|
||||
result = await handler(request)
|
||||
call_log.append("after")
|
||||
@@ -1161,8 +1473,8 @@ class TestSyncAsyncInterop:
|
||||
async def awrap_model_call(
|
||||
self,
|
||||
request: ModelRequest,
|
||||
handler: Callable[[ModelRequest], Awaitable[AIMessage]],
|
||||
) -> AIMessage:
|
||||
handler: Callable[[ModelRequest], Awaitable[ModelResponse]],
|
||||
) -> ModelCallResult:
|
||||
return await handler(request)
|
||||
|
||||
agent = create_agent(
|
||||
@@ -1180,25 +1492,27 @@ class TestSyncAsyncInterop:
|
||||
calls = []
|
||||
|
||||
class MixedMiddleware(AgentMiddleware):
|
||||
def before_model(self, state, runtime) -> None:
|
||||
@override
|
||||
def before_model(self, state: AgentState[Any], runtime: Runtime[Any]) -> None:
|
||||
calls.append("MixedMiddleware.before_model")
|
||||
|
||||
async def abefore_model(self, state, runtime) -> None:
|
||||
@override
|
||||
async def abefore_model(self, state: AgentState[Any], runtime: Runtime[Any]) -> None:
|
||||
calls.append("MixedMiddleware.abefore_model")
|
||||
|
||||
def wrap_model_call(
|
||||
self,
|
||||
request: ModelRequest,
|
||||
handler: Callable[[ModelRequest], AIMessage],
|
||||
) -> AIMessage:
|
||||
handler: Callable[[ModelRequest], ModelResponse],
|
||||
) -> ModelCallResult:
|
||||
calls.append("MixedMiddleware.wrap_model_call")
|
||||
return handler(request)
|
||||
|
||||
async def awrap_model_call(
|
||||
self,
|
||||
request: ModelRequest,
|
||||
handler: Callable[[ModelRequest], Awaitable[AIMessage]],
|
||||
) -> AIMessage:
|
||||
handler: Callable[[ModelRequest], Awaitable[ModelResponse]],
|
||||
) -> ModelCallResult:
|
||||
calls.append("MixedMiddleware.awrap_model_call")
|
||||
return await handler(request)
|
||||
|
||||
@@ -1226,7 +1540,11 @@ class TestEdgeCases:
|
||||
modified_messages = []
|
||||
|
||||
class RequestModifyingMiddleware(AgentMiddleware):
|
||||
def wrap_model_call(self, request, handler):
|
||||
def wrap_model_call(
|
||||
self,
|
||||
request: ModelRequest,
|
||||
handler: Callable[[ModelRequest], ModelResponse],
|
||||
) -> ModelCallResult:
|
||||
# Add a system message to the request
|
||||
modified_request = request
|
||||
modified_messages.append(len(modified_request.messages))
|
||||
@@ -1244,7 +1562,11 @@ class TestEdgeCases:
|
||||
attempts = []
|
||||
|
||||
class MultiModelRetryMiddleware(AgentMiddleware):
|
||||
def wrap_model_call(self, request, handler):
|
||||
def wrap_model_call(
|
||||
self,
|
||||
request: ModelRequest,
|
||||
handler: Callable[[ModelRequest], ModelResponse],
|
||||
) -> ModelCallResult:
|
||||
attempts.append("first-attempt")
|
||||
try:
|
||||
return handler(request)
|
||||
@@ -1255,7 +1577,14 @@ class TestEdgeCases:
|
||||
call_count = {"value": 0}
|
||||
|
||||
class FailFirstSucceedSecond(GenericFakeChatModel):
|
||||
def _generate(self, messages, **kwargs):
|
||||
@override
|
||||
def _generate(
|
||||
self,
|
||||
messages: list[BaseMessage],
|
||||
stop: list[str] | None = None,
|
||||
run_manager: CallbackManagerForLLMRun | None = None,
|
||||
**kwargs: Any,
|
||||
) -> ChatResult:
|
||||
call_count["value"] += 1
|
||||
if call_count["value"] == 1:
|
||||
msg = "First fails"
|
||||
|
||||
Reference in New Issue
Block a user