fix: deprecate setattr on ModelCallRequest (#34022)

* one alternative considered was setting `frozen=True` on the dataclass,
but this is breaking, so a deprecation is a nicer approach
This commit is contained in:
Sydney Runkle
2025-11-19 11:08:55 -05:00
committed by GitHub
parent 328ba36601
commit b7d1831f9d
16 changed files with 866 additions and 611 deletions

3
.gitignore vendored
View File

@@ -163,3 +163,6 @@ node_modules
prof
virtualenv/
scratch/
.langgraph_api/

View File

@@ -10,6 +10,7 @@ chat model.
from __future__ import annotations
from collections.abc import Awaitable, Callable, Iterable, Sequence
from copy import deepcopy
from dataclasses import dataclass
from typing import Literal
@@ -238,10 +239,11 @@ class ContextEditingMiddleware(AgentMiddleware):
system_msg + list(messages), request.tools
)
edited_messages = deepcopy(list(request.messages))
for edit in self.edits:
edit.apply(request.messages, count_tokens=count_tokens)
edit.apply(edited_messages, count_tokens=count_tokens)
return handler(request)
return handler(request.override(messages=edited_messages))
async def awrap_model_call(
self,
@@ -266,10 +268,11 @@ class ContextEditingMiddleware(AgentMiddleware):
system_msg + list(messages), request.tools
)
edited_messages = deepcopy(list(request.messages))
for edit in self.edits:
edit.apply(request.messages, count_tokens=count_tokens)
edit.apply(edited_messages, count_tokens=count_tokens)
return await handler(request)
return await handler(request.override(messages=edited_messages))
__all__ = [

View File

@@ -92,9 +92,8 @@ class ModelFallbackMiddleware(AgentMiddleware):
# Try fallback models
for fallback_model in self.models:
request.model = fallback_model
try:
return handler(request)
return handler(request.override(model=fallback_model))
except Exception as e: # noqa: BLE001
last_exception = e
continue
@@ -127,9 +126,8 @@ class ModelFallbackMiddleware(AgentMiddleware):
# Try fallback models
for fallback_model in self.models:
request.model = fallback_model
try:
return await handler(request)
return await handler(request.override(model=fallback_model))
except Exception as e: # noqa: BLE001
last_exception = e
continue

View File

@@ -194,12 +194,12 @@ class TodoListMiddleware(AgentMiddleware):
handler: Callable[[ModelRequest], ModelResponse],
) -> ModelCallResult:
"""Update the system prompt to include the todo system prompt."""
request.system_prompt = (
new_system_prompt = (
request.system_prompt + "\n\n" + self.system_prompt
if request.system_prompt
else self.system_prompt
)
return handler(request)
return handler(request.override(system_prompt=new_system_prompt))
async def awrap_model_call(
self,
@@ -207,9 +207,9 @@ class TodoListMiddleware(AgentMiddleware):
handler: Callable[[ModelRequest], Awaitable[ModelResponse]],
) -> ModelCallResult:
"""Update the system prompt to include the todo system prompt (async version)."""
request.system_prompt = (
new_system_prompt = (
request.system_prompt + "\n\n" + self.system_prompt
if request.system_prompt
else self.system_prompt
)
return await handler(request)
return await handler(request.override(system_prompt=new_system_prompt))

View File

@@ -255,8 +255,7 @@ class LLMToolSelectorMiddleware(AgentMiddleware):
# Also preserve any provider-specific tool dicts from the original request
provider_tools = [tool for tool in request.tools if isinstance(tool, dict)]
request.tools = [*selected_tools, *provider_tools]
return request
return request.override(tools=[*selected_tools, *provider_tools])
def wrap_model_call(
self,

View File

@@ -94,6 +94,31 @@ class ModelRequest:
runtime: Runtime[ContextT] # type: ignore[valid-type]
model_settings: dict[str, Any] = field(default_factory=dict)
def __setattr__(self, name: str, value: Any) -> None:
    """Set an attribute, warning when an already-set attribute is reassigned.

    Direct attribute assignment on `ModelRequest` is deprecated. Use the
    `override()` method instead to create a new request with modified attributes.

    The assignment is always performed; the warning is advisory only.

    Args:
        name: Attribute name.
        value: Attribute value.
    """
    import warnings

    # A first-time assignment (as performed by `__init__`) has no entry in the
    # instance `__dict__` yet. Looking at the instance dict instead of using
    # `hasattr` keeps class-level defaults from being mistaken for a value that
    # was already set, which would emit a spurious warning during construction.
    instance_attrs = getattr(self, "__dict__", None)
    if instance_attrs is not None:
        already_set = name in instance_attrs
    else:
        # No instance dict available: fall back to plain attribute lookup.
        already_set = hasattr(self, name)

    if already_set:
        warnings.warn(
            f"Direct attribute assignment to ModelRequest.{name} is deprecated. "
            f"Use request.override({name}=...) instead to create a new request "
            f"with the modified attribute.",
            DeprecationWarning,
            # Point the warning at the caller's assignment, not at this method.
            stacklevel=2,
        )
    object.__setattr__(self, name, value)
def override(self, **overrides: Unpack[_ModelRequestOverrides]) -> ModelRequest:
"""Replace the request with a new request with the given overrides.
@@ -446,7 +471,14 @@ class AgentMiddleware(Generic[StateT, ContextT]):
```python
def wrap_tool_call(self, request, handler):
request.tool_call["args"]["value"] *= 2
modified_call = {
**request.tool_call,
"args": {
**request.tool_call["args"],
"value": request.tool_call["args"]["value"] * 2,
},
}
request = request.override(tool_call=modified_call)
return handler(request)
```
@@ -1337,7 +1369,7 @@ def dynamic_prompt(
handler: Callable[[ModelRequest], Awaitable[ModelResponse]],
) -> ModelCallResult:
prompt = await func(request) # type: ignore[misc]
request.system_prompt = prompt
request = request.override(system_prompt=prompt)
return await handler(request)
middleware_name = cast("str", getattr(func, "__name__", "DynamicPromptMiddleware"))
@@ -1358,7 +1390,7 @@ def dynamic_prompt(
handler: Callable[[ModelRequest], ModelResponse],
) -> ModelCallResult:
prompt = cast("str", func(request))
request.system_prompt = prompt
request = request.override(system_prompt=prompt)
return handler(request)
async def async_wrapped_from_sync(
@@ -1368,7 +1400,7 @@ def dynamic_prompt(
) -> ModelCallResult:
# Delegate to sync function
prompt = cast("str", func(request))
request.system_prompt = prompt
request = request.override(system_prompt=prompt)
return await handler(request)
middleware_name = cast("str", getattr(func, "__name__", "DynamicPromptMiddleware"))
@@ -1469,7 +1501,7 @@ def wrap_model_call(
pass
# Try fallback model
request.model = fallback_model_instance
request = request.override(model=fallback_model_instance)
return handler(request)
```
@@ -1632,7 +1664,14 @@ def wrap_tool_call(
```python
@wrap_tool_call
def modify_args(request, handler):
request.tool_call["args"]["value"] *= 2
modified_call = {
**request.tool_call,
"args": {
**request.tool_call["args"],
"value": request.tool_call["args"]["value"] * 2,
},
}
request = request.override(tool_call=modified_call)
return handler(request)
```

View File

@@ -230,9 +230,7 @@ class TestChainModelCallHandlers:
test_runtime = {"test": "runtime"}
# Create request with state and runtime
test_request = create_test_request()
test_request.state = test_state
test_request.runtime = test_runtime
test_request = create_test_request(state=test_state, runtime=test_runtime)
result = composed(test_request, create_mock_base_handler())
# Both handlers should see same state and runtime

View File

@@ -90,8 +90,7 @@ def test_on_model_call_decorator() -> None:
@wrap_model_call(state_schema=CustomState, tools=[test_tool], name="CustomOnModelCall")
def custom_on_model_call(request, handler):
request.system_prompt = "Modified"
return handler(request)
return handler(request.override(system_prompt="Modified"))
# Verify all options were applied
assert isinstance(custom_on_model_call, AgentMiddleware)
@@ -277,8 +276,7 @@ def test_async_on_model_call_decorator() -> None:
@wrap_model_call(state_schema=CustomState, tools=[test_tool], name="AsyncOnModelCall")
async def async_on_model_call(request, handler):
request.system_prompt = "Modified async"
return await handler(request)
return await handler(request.override(system_prompt="Modified async"))
assert isinstance(async_on_model_call, AgentMiddleware)
assert async_on_model_call.state_schema == CustomState

View File

@@ -79,8 +79,8 @@ def test_middleware_can_modify_tools() -> None:
handler: Callable[[ModelRequest], AIMessage],
) -> AIMessage:
# Only allow tool_a and tool_b
request.tools = [t for t in request.tools if t.name in ["tool_a", "tool_b"]]
return handler(request)
filtered_tools = [t for t in request.tools if t.name in ["tool_a", "tool_b"]]
return handler(request.override(tools=filtered_tools))
# Model will try to call tool_a
model = FakeToolCallingModel(
@@ -123,8 +123,7 @@ def test_unknown_tool_raises_error() -> None:
handler: Callable[[ModelRequest], AIMessage],
) -> AIMessage:
# Add an unknown tool
request.tools = request.tools + [unknown_tool]
return handler(request)
return handler(request.override(tools=request.tools + [unknown_tool]))
agent = create_agent(
model=FakeToolCallingModel(),
@@ -163,7 +162,8 @@ def test_middleware_can_add_and_remove_tools() -> None:
) -> AIMessage:
# Remove admin_tool if not admin
if not request.state.get("is_admin", False):
request.tools = [t for t in request.tools if t.name != "admin_tool"]
filtered_tools = [t for t in request.tools if t.name != "admin_tool"]
request = request.override(tools=filtered_tools)
return handler(request)
model = FakeToolCallingModel()
@@ -200,7 +200,7 @@ def test_empty_tools_list_is_valid() -> None:
handler: Callable[[ModelRequest], AIMessage],
) -> AIMessage:
# Remove all tools
request.tools = []
request = request.override(tools=[])
return handler(request)
model = FakeToolCallingModel()
@@ -244,7 +244,8 @@ def test_tools_preserved_across_multiple_middleware() -> None:
) -> AIMessage:
modification_order.append([t.name for t in request.tools])
# Remove tool_c
request.tools = [t for t in request.tools if t.name != "tool_c"]
filtered_tools = [t for t in request.tools if t.name != "tool_c"]
request = request.override(tools=filtered_tools)
return handler(request)
class SecondMiddleware(AgentMiddleware):
@@ -257,7 +258,8 @@ def test_tools_preserved_across_multiple_middleware() -> None:
# Should not see tool_c here
assert all(t.name != "tool_c" for t in request.tools)
# Remove tool_b
request.tools = [t for t in request.tools if t.name != "tool_b"]
filtered_tools = [t for t in request.tools if t.name != "tool_b"]
request = request.override(tools=filtered_tools)
return handler(request)
agent = create_agent(

View File

@@ -82,16 +82,23 @@ def test_no_edit_when_below_trigger() -> None:
edits=[ClearToolUsesEdit(trigger=50)],
)
modified_request = None
def mock_handler(req: ModelRequest) -> AIMessage:
nonlocal modified_request
modified_request = req
return AIMessage(content="mock response")
# Call wrap_model_call which modifies the request
# Call wrap_model_call which creates a new request
middleware.wrap_model_call(request, mock_handler)
# The request should have been modified in place
# The modified request passed to handler should be the same since no edits applied
assert modified_request is not None
assert modified_request.messages[0].content == ""
assert modified_request.messages[1].content == "12345"
# Original request should be unchanged
assert request.messages[0].content == ""
assert request.messages[1].content == "12345"
assert state["messages"] == request.messages
def test_clear_tool_outputs_and_inputs() -> None:
@@ -115,14 +122,19 @@ def test_clear_tool_outputs_and_inputs() -> None:
)
middleware = ContextEditingMiddleware(edits=[edit])
modified_request = None
def mock_handler(req: ModelRequest) -> AIMessage:
nonlocal modified_request
modified_request = req
return AIMessage(content="mock response")
# Call wrap_model_call which modifies the request
# Call wrap_model_call which creates a new request with edits
middleware.wrap_model_call(request, mock_handler)
cleared_ai = request.messages[0]
cleared_tool = request.messages[1]
assert modified_request is not None
cleared_ai = modified_request.messages[0]
cleared_tool = modified_request.messages[1]
assert isinstance(cleared_tool, ToolMessage)
assert cleared_tool.content == "[cleared output]"
@@ -134,7 +146,9 @@ def test_clear_tool_outputs_and_inputs() -> None:
assert context_meta is not None
assert context_meta["cleared_tool_inputs"] == [tool_call_id]
assert state["messages"] == request.messages
# Original request should be unchanged
assert request.messages[0].tool_calls[0]["args"] == {"query": "foo"}
assert request.messages[1].content == "x" * 200
def test_respects_keep_last_tool_results() -> None:
@@ -167,21 +181,26 @@ def test_respects_keep_last_tool_results() -> None:
token_count_method="model",
)
modified_request = None
def mock_handler(req: ModelRequest) -> AIMessage:
nonlocal modified_request
modified_request = req
return AIMessage(content="mock response")
# Call wrap_model_call which modifies the request
# Call wrap_model_call which creates a new request with edits
middleware.wrap_model_call(request, mock_handler)
assert modified_request is not None
cleared_messages = [
msg
for msg in request.messages
for msg in modified_request.messages
if isinstance(msg, ToolMessage) and msg.content == "[cleared]"
]
assert len(cleared_messages) == 2
assert isinstance(request.messages[-1], ToolMessage)
assert request.messages[-1].content != "[cleared]"
assert isinstance(modified_request.messages[-1], ToolMessage)
assert modified_request.messages[-1].content != "[cleared]"
def test_exclude_tools_prevents_clearing() -> None:
@@ -215,14 +234,19 @@ def test_exclude_tools_prevents_clearing() -> None:
],
)
modified_request = None
def mock_handler(req: ModelRequest) -> AIMessage:
nonlocal modified_request
modified_request = req
return AIMessage(content="mock response")
# Call wrap_model_call which modifies the request
# Call wrap_model_call which creates a new request with edits
middleware.wrap_model_call(request, mock_handler)
search_tool = request.messages[1]
calc_tool = request.messages[3]
assert modified_request is not None
search_tool = modified_request.messages[1]
calc_tool = modified_request.messages[3]
assert isinstance(search_tool, ToolMessage)
assert search_tool.content == "search-results" * 20
@@ -249,16 +273,23 @@ async def test_no_edit_when_below_trigger_async() -> None:
edits=[ClearToolUsesEdit(trigger=50)],
)
modified_request = None
async def mock_handler(req: ModelRequest) -> AIMessage:
nonlocal modified_request
modified_request = req
return AIMessage(content="mock response")
# Call awrap_model_call which modifies the request
# Call awrap_model_call which creates a new request
await middleware.awrap_model_call(request, mock_handler)
# The request should have been modified in place
# The modified request passed to handler should be the same since no edits applied
assert modified_request is not None
assert modified_request.messages[0].content == ""
assert modified_request.messages[1].content == "12345"
# Original request should be unchanged
assert request.messages[0].content == ""
assert request.messages[1].content == "12345"
assert state["messages"] == request.messages
async def test_clear_tool_outputs_and_inputs_async() -> None:
@@ -283,14 +314,19 @@ async def test_clear_tool_outputs_and_inputs_async() -> None:
)
middleware = ContextEditingMiddleware(edits=[edit])
modified_request = None
async def mock_handler(req: ModelRequest) -> AIMessage:
nonlocal modified_request
modified_request = req
return AIMessage(content="mock response")
# Call awrap_model_call which modifies the request
# Call awrap_model_call which creates a new request with edits
await middleware.awrap_model_call(request, mock_handler)
cleared_ai = request.messages[0]
cleared_tool = request.messages[1]
assert modified_request is not None
cleared_ai = modified_request.messages[0]
cleared_tool = modified_request.messages[1]
assert isinstance(cleared_tool, ToolMessage)
assert cleared_tool.content == "[cleared output]"
@@ -302,7 +338,9 @@ async def test_clear_tool_outputs_and_inputs_async() -> None:
assert context_meta is not None
assert context_meta["cleared_tool_inputs"] == [tool_call_id]
assert state["messages"] == request.messages
# Original request should be unchanged
assert request.messages[0].tool_calls[0]["args"] == {"query": "foo"}
assert request.messages[1].content == "x" * 200
async def test_respects_keep_last_tool_results_async() -> None:
@@ -336,21 +374,26 @@ async def test_respects_keep_last_tool_results_async() -> None:
token_count_method="model",
)
modified_request = None
async def mock_handler(req: ModelRequest) -> AIMessage:
nonlocal modified_request
modified_request = req
return AIMessage(content="mock response")
# Call awrap_model_call which modifies the request
# Call awrap_model_call which creates a new request with edits
await middleware.awrap_model_call(request, mock_handler)
assert modified_request is not None
cleared_messages = [
msg
for msg in request.messages
for msg in modified_request.messages
if isinstance(msg, ToolMessage) and msg.content == "[cleared]"
]
assert len(cleared_messages) == 2
assert isinstance(request.messages[-1], ToolMessage)
assert request.messages[-1].content != "[cleared]"
assert isinstance(modified_request.messages[-1], ToolMessage)
assert modified_request.messages[-1].content != "[cleared]"
async def test_exclude_tools_prevents_clearing_async() -> None:
@@ -385,14 +428,19 @@ async def test_exclude_tools_prevents_clearing_async() -> None:
],
)
modified_request = None
async def mock_handler(req: ModelRequest) -> AIMessage:
nonlocal modified_request
modified_request = req
return AIMessage(content="mock response")
# Call awrap_model_call which modifies the request
# Call awrap_model_call which creates a new request with edits
await middleware.awrap_model_call(request, mock_handler)
search_tool = request.messages[1]
calc_tool = request.messages[3]
assert modified_request is not None
search_tool = modified_request.messages[1]
calc_tool = modified_request.messages[3]
assert isinstance(search_tool, ToolMessage)
assert search_tool.content == "search-results" * 20

View File

@@ -2,6 +2,7 @@
from __future__ import annotations
import warnings
from typing import cast
import pytest
@@ -45,7 +46,7 @@ def test_primary_model_succeeds() -> None:
middleware = ModelFallbackMiddleware(fallback_model)
request = _make_request()
request.model = primary_model
request = request.override(model=primary_model)
def mock_handler(req: ModelRequest) -> ModelResponse:
# Simulate successful model call
@@ -70,7 +71,7 @@ def test_fallback_on_primary_failure() -> None:
middleware = ModelFallbackMiddleware(fallback_model)
request = _make_request()
request.model = primary_model
request = request.override(model=primary_model)
def mock_handler(req: ModelRequest) -> ModelResponse:
result = req.model.invoke([])
@@ -95,7 +96,7 @@ def test_multiple_fallbacks() -> None:
middleware = ModelFallbackMiddleware(fallback1, fallback2)
request = _make_request()
request.model = primary_model
request = request.override(model=primary_model)
def mock_handler(req: ModelRequest) -> ModelResponse:
result = req.model.invoke([])
@@ -119,7 +120,7 @@ def test_all_models_fail() -> None:
middleware = ModelFallbackMiddleware(fallback_model)
request = _make_request()
request.model = primary_model
request = request.override(model=primary_model)
def mock_handler(req: ModelRequest) -> ModelResponse:
result = req.model.invoke([])
@@ -136,7 +137,7 @@ async def test_primary_model_succeeds_async() -> None:
middleware = ModelFallbackMiddleware(fallback_model)
request = _make_request()
request.model = primary_model
request = request.override(model=primary_model)
async def mock_handler(req: ModelRequest) -> ModelResponse:
# Simulate successful async model call
@@ -161,7 +162,7 @@ async def test_fallback_on_primary_failure_async() -> None:
middleware = ModelFallbackMiddleware(fallback_model)
request = _make_request()
request.model = primary_model
request = request.override(model=primary_model)
async def mock_handler(req: ModelRequest) -> ModelResponse:
result = await req.model.ainvoke([])
@@ -186,7 +187,7 @@ async def test_multiple_fallbacks_async() -> None:
middleware = ModelFallbackMiddleware(fallback1, fallback2)
request = _make_request()
request.model = primary_model
request = request.override(model=primary_model)
async def mock_handler(req: ModelRequest) -> ModelResponse:
result = await req.model.ainvoke([])
@@ -210,7 +211,7 @@ async def test_all_models_fail_async() -> None:
middleware = ModelFallbackMiddleware(fallback_model)
request = _make_request()
request.model = primary_model
request = request.override(model=primary_model)
async def mock_handler(req: ModelRequest) -> ModelResponse:
result = await req.model.ainvoke([])
@@ -305,3 +306,46 @@ def test_model_fallback_middleware_initialization() -> None:
# Test with multiple fallback models
middleware = ModelFallbackMiddleware(FakeToolCallingModel(), FakeToolCallingModel())
assert len(middleware.models) == 2
def test_model_request_is_frozen() -> None:
    """Check that assigning to ModelRequest attributes warns while override() stays silent."""
    mutated = _make_request()
    replacement = GenericFakeChatModel(messages=iter([AIMessage(content="new model")]))

    # Each direct assignment must emit a DeprecationWarning and still take effect.
    with pytest.warns(
        DeprecationWarning, match="Direct attribute assignment to ModelRequest.model is deprecated"
    ):
        mutated.model = replacement  # type: ignore[misc]
    assert mutated.model == replacement

    with pytest.warns(
        DeprecationWarning,
        match="Direct attribute assignment to ModelRequest.system_prompt is deprecated",
    ):
        mutated.system_prompt = "new prompt"  # type: ignore[misc]
    assert mutated.system_prompt == "new prompt"

    with pytest.warns(
        DeprecationWarning,
        match="Direct attribute assignment to ModelRequest.messages is deprecated",
    ):
        mutated.messages = []  # type: ignore[misc]
    assert mutated.messages == []

    # override() is the supported path: escalate warnings to errors so that any
    # warning raised while building the new request fails the test.
    pristine = _make_request()
    with warnings.catch_warnings():
        warnings.simplefilter("error")
        derived = pristine.override(model=replacement, system_prompt="override prompt")
    assert derived.model == replacement
    assert derived.system_prompt == "override prompt"

    # The request that override() was called on keeps its original values.
    assert pristine.model != replacement
    assert pristine.system_prompt != "override prompt"

View File

@@ -83,14 +83,21 @@ def test_adds_system_prompt_when_none_exists() -> None:
middleware = TodoListMiddleware()
request = _make_request(system_prompt=None)
captured_request = None
def mock_handler(req: ModelRequest) -> ModelResponse:
nonlocal captured_request
captured_request = req
return ModelResponse(result=[AIMessage(content="response")])
middleware.wrap_model_call(request, mock_handler)
# System prompt should be set
assert request.system_prompt is not None
assert "write_todos" in request.system_prompt
# System prompt should be set in the modified request passed to handler
assert captured_request is not None
assert captured_request.system_prompt is not None
assert "write_todos" in captured_request.system_prompt
# Original request should be unchanged
assert request.system_prompt is None
def test_appends_to_existing_system_prompt() -> None:
@@ -99,16 +106,23 @@ def test_appends_to_existing_system_prompt() -> None:
middleware = TodoListMiddleware()
request = _make_request(system_prompt=existing_prompt)
captured_request = None
def mock_handler(req: ModelRequest) -> ModelResponse:
nonlocal captured_request
captured_request = req
return ModelResponse(result=[AIMessage(content="response")])
middleware.wrap_model_call(request, mock_handler)
# System prompt should contain both
assert request.system_prompt is not None
assert existing_prompt in request.system_prompt
assert "write_todos" in request.system_prompt
assert request.system_prompt.startswith(existing_prompt)
# System prompt should contain both in the modified request passed to handler
assert captured_request is not None
assert captured_request.system_prompt is not None
assert existing_prompt in captured_request.system_prompt
assert "write_todos" in captured_request.system_prompt
assert captured_request.system_prompt.startswith(existing_prompt)
# Original request should be unchanged
assert request.system_prompt == existing_prompt
@pytest.mark.parametrize(
@@ -137,13 +151,20 @@ def test_todo_middleware_on_model_call(original_prompt, expected_prompt_prefix)
model_settings={},
)
captured_request = None
def mock_handler(req: ModelRequest) -> AIMessage:
nonlocal captured_request
captured_request = req
return AIMessage(content="mock response")
# Call wrap_model_call to trigger the middleware logic
middleware.wrap_model_call(request, mock_handler)
# Check that the request was modified in place
assert request.system_prompt.startswith(expected_prompt_prefix)
# Check that the modified request passed to handler has the expected prompt
assert captured_request is not None
assert captured_request.system_prompt.startswith(expected_prompt_prefix)
# Original request should be unchanged
assert request.system_prompt == original_prompt
def test_custom_system_prompt() -> None:
@@ -152,13 +173,20 @@ def test_custom_system_prompt() -> None:
middleware = TodoListMiddleware(system_prompt=custom_prompt)
request = _make_request(system_prompt=None)
captured_request = None
def mock_handler(req: ModelRequest) -> ModelResponse:
nonlocal captured_request
captured_request = req
return ModelResponse(result=[AIMessage(content="response")])
middleware.wrap_model_call(request, mock_handler)
# Should use custom prompt
assert request.system_prompt == custom_prompt
# Should use custom prompt in the modified request passed to handler
assert captured_request is not None
assert captured_request.system_prompt == custom_prompt
# Original request should be unchanged
assert request.system_prompt is None
def test_todo_middleware_custom_system_prompt() -> None:
@@ -181,13 +209,20 @@ def test_todo_middleware_custom_system_prompt() -> None:
runtime=cast(Runtime, object()),
)
captured_request = None
def mock_handler(req: ModelRequest) -> AIMessage:
nonlocal captured_request
captured_request = req
return AIMessage(content="mock response")
# Call wrap_model_call to trigger the middleware logic
middleware.wrap_model_call(request, mock_handler)
# Check that the request was modified in place
assert request.system_prompt == f"Original prompt\n\n{custom_system_prompt}"
# Check that the modified request passed to handler has the expected prompt
assert captured_request is not None
assert captured_request.system_prompt == f"Original prompt\n\n{custom_system_prompt}"
# Original request should be unchanged
assert request.system_prompt == "Original prompt"
def test_custom_tool_description() -> None:
@@ -235,13 +270,20 @@ def test_todo_middleware_custom_system_prompt_and_tool_description() -> None:
model_settings={},
)
captured_request = None
def mock_handler(req: ModelRequest) -> AIMessage:
nonlocal captured_request
captured_request = req
return AIMessage(content="mock response")
# Call wrap_model_call to trigger the middleware logic
middleware.wrap_model_call(request, mock_handler)
# Check that the request was modified in place
assert request.system_prompt == custom_system_prompt
# Check that the modified request passed to handler has the expected prompt
assert captured_request is not None
assert captured_request.system_prompt == custom_system_prompt
# Original request should be unchanged
assert request.system_prompt is None
# Verify tool description
assert len(middleware.tools) == 1
@@ -390,14 +432,21 @@ async def test_adds_system_prompt_when_none_exists_async() -> None:
middleware = TodoListMiddleware()
request = _make_request(system_prompt=None)
captured_request = None
async def mock_handler(req: ModelRequest) -> ModelResponse:
nonlocal captured_request
captured_request = req
return ModelResponse(result=[AIMessage(content="response")])
await middleware.awrap_model_call(request, mock_handler)
# System prompt should be set
assert request.system_prompt is not None
assert "write_todos" in request.system_prompt
# System prompt should be set in the modified request passed to handler
assert captured_request is not None
assert captured_request.system_prompt is not None
assert "write_todos" in captured_request.system_prompt
# Original request should be unchanged
assert request.system_prompt is None
async def test_appends_to_existing_system_prompt_async() -> None:
@@ -406,16 +455,23 @@ async def test_appends_to_existing_system_prompt_async() -> None:
middleware = TodoListMiddleware()
request = _make_request(system_prompt=existing_prompt)
captured_request = None
async def mock_handler(req: ModelRequest) -> ModelResponse:
nonlocal captured_request
captured_request = req
return ModelResponse(result=[AIMessage(content="response")])
await middleware.awrap_model_call(request, mock_handler)
# System prompt should contain both
assert request.system_prompt is not None
assert existing_prompt in request.system_prompt
assert "write_todos" in request.system_prompt
assert request.system_prompt.startswith(existing_prompt)
# System prompt should contain both in the modified request passed to handler
assert captured_request is not None
assert captured_request.system_prompt is not None
assert existing_prompt in captured_request.system_prompt
assert "write_todos" in captured_request.system_prompt
assert captured_request.system_prompt.startswith(existing_prompt)
# Original request should be unchanged
assert request.system_prompt == existing_prompt
async def test_custom_system_prompt_async() -> None:
@@ -424,13 +480,20 @@ async def test_custom_system_prompt_async() -> None:
middleware = TodoListMiddleware(system_prompt=custom_prompt)
request = _make_request(system_prompt=None)
captured_request = None
async def mock_handler(req: ModelRequest) -> ModelResponse:
nonlocal captured_request
captured_request = req
return ModelResponse(result=[AIMessage(content="response")])
await middleware.awrap_model_call(request, mock_handler)
# Should use custom prompt
assert request.system_prompt == custom_prompt
# Should use custom prompt in the modified request passed to handler
assert captured_request is not None
assert captured_request.system_prompt == custom_prompt
# Original request should be unchanged
assert request.system_prompt is None
async def test_handler_called_with_modified_request_async() -> None:

View File

@@ -784,8 +784,7 @@ class TestDynamicModelWithResponseFormat:
handler: Callable[[ModelRequest], CoreAIMessage],
) -> CoreAIMessage:
# Replace the model with our custom test model
request.model = model
return handler(request)
return handler(request.override(model=model))
# Track which model is checked for provider strategy support
calls = []

View File

@@ -94,14 +94,6 @@ class AnthropicPromptCachingMiddleware(AgentMiddleware):
)
return messages_count >= self.min_messages_to_cache
def _apply_cache_control(self, request: ModelRequest) -> None:
"""Apply cache control settings to the request.
Args:
request: The model request to modify.
"""
request.model_settings["cache_control"] = {"type": self.type, "ttl": self.ttl}
def wrap_model_call(
self,
request: ModelRequest,
@@ -119,8 +111,12 @@ class AnthropicPromptCachingMiddleware(AgentMiddleware):
if not self._should_apply_caching(request):
return handler(request)
self._apply_cache_control(request)
return handler(request)
model_settings = request.model_settings
new_model_settings = {
**model_settings,
"cache_control": {"type": self.type, "ttl": self.ttl},
}
return handler(request.override(model_settings=new_model_settings))
async def awrap_model_call(
self,
@@ -139,5 +135,9 @@ class AnthropicPromptCachingMiddleware(AgentMiddleware):
if not self._should_apply_caching(request):
return await handler(request)
self._apply_cache_control(request)
return await handler(request)
model_settings = request.model_settings
new_model_settings = {
**model_settings,
"cache_control": {"type": self.type, "ttl": self.ttl},
}
return await handler(request.override(model_settings=new_model_settings))

View File

@@ -82,12 +82,17 @@ def test_anthropic_prompt_caching_middleware_initialization() -> None:
model_settings={},
)
modified_request: ModelRequest | None = None
def mock_handler(req: ModelRequest) -> ModelResponse:
nonlocal modified_request
modified_request = req
return ModelResponse(result=[AIMessage(content="mock response")])
middleware.wrap_model_call(fake_request, mock_handler)
# Check that model_settings were passed through via the request
assert fake_request.model_settings == {
assert modified_request is not None
assert modified_request.model_settings == {
"cache_control": {"type": "ephemeral", "ttl": "5m"}
}
@@ -162,13 +167,18 @@ async def test_anthropic_prompt_caching_middleware_async() -> None:
model_settings={},
)
modified_request: ModelRequest | None = None
async def mock_handler(req: ModelRequest) -> ModelResponse:
nonlocal modified_request
modified_request = req
return ModelResponse(result=[AIMessage(content="mock response")])
result = await middleware.awrap_model_call(fake_request, mock_handler)
assert isinstance(result, ModelResponse)
# Check that model_settings were passed through via the request
assert fake_request.model_settings == {
assert modified_request is not None
assert modified_request.model_settings == {
"cache_control": {"type": "ephemeral", "ttl": "1h"}
}
@@ -237,13 +247,18 @@ async def test_anthropic_prompt_caching_middleware_async_min_messages() -> None:
model_settings={},
)
modified_request: ModelRequest | None = None
async def mock_handler(req: ModelRequest) -> ModelResponse:
nonlocal modified_request
modified_request = req
return ModelResponse(result=[AIMessage(content="mock response")])
result = await middleware.awrap_model_call(fake_request, mock_handler)
assert isinstance(result, ModelResponse)
# Cache control should NOT be added when message count is below minimum
assert fake_request.model_settings == {}
assert modified_request is not None
assert modified_request.model_settings == {}
async def test_anthropic_prompt_caching_middleware_async_with_system_prompt() -> None:
@@ -268,13 +283,18 @@ async def test_anthropic_prompt_caching_middleware_async_with_system_prompt() ->
model_settings={},
)
modified_request: ModelRequest | None = None
async def mock_handler(req: ModelRequest) -> ModelResponse:
nonlocal modified_request
modified_request = req
return ModelResponse(result=[AIMessage(content="mock response")])
result = await middleware.awrap_model_call(fake_request, mock_handler)
assert isinstance(result, ModelResponse)
# Cache control should be added when system prompt pushes count to minimum
assert fake_request.model_settings == {
assert modified_request is not None
assert modified_request.model_settings == {
"cache_control": {"type": "ephemeral", "ttl": "1h"}
}
@@ -300,12 +320,17 @@ async def test_anthropic_prompt_caching_middleware_async_default_values() -> Non
model_settings={},
)
modified_request: ModelRequest | None = None
async def mock_handler(req: ModelRequest) -> ModelResponse:
nonlocal modified_request
modified_request = req
return ModelResponse(result=[AIMessage(content="mock response")])
result = await middleware.awrap_model_call(fake_request, mock_handler)
assert isinstance(result, ModelResponse)
# Check that model_settings were added with default values
assert fake_request.model_settings == {
assert modified_request is not None
assert modified_request.model_settings == {
"cache_control": {"type": "ephemeral", "ttl": "5m"}
}

File diff suppressed because it is too large Load Diff