mirror of
https://github.com/hwchase17/langchain.git
synced 2026-06-09 10:17:00 +00:00
fix: deprecate setattr on ModelCallRequest (#34022)
* one alternative considered was setting `frozen=True` on the dataclass, but this is breaking, so a deprecation is a nicer approach
This commit is contained in:
3
.gitignore
vendored
3
.gitignore
vendored
@@ -163,3 +163,6 @@ node_modules
|
||||
|
||||
prof
|
||||
virtualenv/
|
||||
scratch/
|
||||
|
||||
.langgraph_api/
|
||||
|
||||
@@ -10,6 +10,7 @@ chat model.
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Awaitable, Callable, Iterable, Sequence
|
||||
from copy import deepcopy
|
||||
from dataclasses import dataclass
|
||||
from typing import Literal
|
||||
|
||||
@@ -238,10 +239,11 @@ class ContextEditingMiddleware(AgentMiddleware):
|
||||
system_msg + list(messages), request.tools
|
||||
)
|
||||
|
||||
edited_messages = deepcopy(list(request.messages))
|
||||
for edit in self.edits:
|
||||
edit.apply(request.messages, count_tokens=count_tokens)
|
||||
edit.apply(edited_messages, count_tokens=count_tokens)
|
||||
|
||||
return handler(request)
|
||||
return handler(request.override(messages=edited_messages))
|
||||
|
||||
async def awrap_model_call(
|
||||
self,
|
||||
@@ -266,10 +268,11 @@ class ContextEditingMiddleware(AgentMiddleware):
|
||||
system_msg + list(messages), request.tools
|
||||
)
|
||||
|
||||
edited_messages = deepcopy(list(request.messages))
|
||||
for edit in self.edits:
|
||||
edit.apply(request.messages, count_tokens=count_tokens)
|
||||
edit.apply(edited_messages, count_tokens=count_tokens)
|
||||
|
||||
return await handler(request)
|
||||
return await handler(request.override(messages=edited_messages))
|
||||
|
||||
|
||||
__all__ = [
|
||||
|
||||
@@ -92,9 +92,8 @@ class ModelFallbackMiddleware(AgentMiddleware):
|
||||
|
||||
# Try fallback models
|
||||
for fallback_model in self.models:
|
||||
request.model = fallback_model
|
||||
try:
|
||||
return handler(request)
|
||||
return handler(request.override(model=fallback_model))
|
||||
except Exception as e: # noqa: BLE001
|
||||
last_exception = e
|
||||
continue
|
||||
@@ -127,9 +126,8 @@ class ModelFallbackMiddleware(AgentMiddleware):
|
||||
|
||||
# Try fallback models
|
||||
for fallback_model in self.models:
|
||||
request.model = fallback_model
|
||||
try:
|
||||
return await handler(request)
|
||||
return await handler(request.override(model=fallback_model))
|
||||
except Exception as e: # noqa: BLE001
|
||||
last_exception = e
|
||||
continue
|
||||
|
||||
@@ -194,12 +194,12 @@ class TodoListMiddleware(AgentMiddleware):
|
||||
handler: Callable[[ModelRequest], ModelResponse],
|
||||
) -> ModelCallResult:
|
||||
"""Update the system prompt to include the todo system prompt."""
|
||||
request.system_prompt = (
|
||||
new_system_prompt = (
|
||||
request.system_prompt + "\n\n" + self.system_prompt
|
||||
if request.system_prompt
|
||||
else self.system_prompt
|
||||
)
|
||||
return handler(request)
|
||||
return handler(request.override(system_prompt=new_system_prompt))
|
||||
|
||||
async def awrap_model_call(
|
||||
self,
|
||||
@@ -207,9 +207,9 @@ class TodoListMiddleware(AgentMiddleware):
|
||||
handler: Callable[[ModelRequest], Awaitable[ModelResponse]],
|
||||
) -> ModelCallResult:
|
||||
"""Update the system prompt to include the todo system prompt (async version)."""
|
||||
request.system_prompt = (
|
||||
new_system_prompt = (
|
||||
request.system_prompt + "\n\n" + self.system_prompt
|
||||
if request.system_prompt
|
||||
else self.system_prompt
|
||||
)
|
||||
return await handler(request)
|
||||
return await handler(request.override(system_prompt=new_system_prompt))
|
||||
|
||||
@@ -255,8 +255,7 @@ class LLMToolSelectorMiddleware(AgentMiddleware):
|
||||
# Also preserve any provider-specific tool dicts from the original request
|
||||
provider_tools = [tool for tool in request.tools if isinstance(tool, dict)]
|
||||
|
||||
request.tools = [*selected_tools, *provider_tools]
|
||||
return request
|
||||
return request.override(tools=[*selected_tools, *provider_tools])
|
||||
|
||||
def wrap_model_call(
|
||||
self,
|
||||
|
||||
@@ -94,6 +94,31 @@ class ModelRequest:
|
||||
runtime: Runtime[ContextT] # type: ignore[valid-type]
|
||||
model_settings: dict[str, Any] = field(default_factory=dict)
|
||||
|
||||
def __setattr__(self, name: str, value: Any) -> None:
|
||||
"""Set an attribute with a deprecation warning.
|
||||
|
||||
Direct attribute assignment on `ModelRequest` is deprecated. Use the
|
||||
`override()` method instead to create a new request with modified attributes.
|
||||
|
||||
Args:
|
||||
name: Attribute name.
|
||||
value: Attribute value.
|
||||
"""
|
||||
import warnings
|
||||
|
||||
# Allow setting attributes during __init__ (when object is being constructed)
|
||||
if not hasattr(self, "__dataclass_fields__") or not hasattr(self, name):
|
||||
object.__setattr__(self, name, value)
|
||||
else:
|
||||
warnings.warn(
|
||||
f"Direct attribute assignment to ModelRequest.{name} is deprecated. "
|
||||
f"Use request.override({name}=...) instead to create a new request "
|
||||
f"with the modified attribute.",
|
||||
DeprecationWarning,
|
||||
stacklevel=2,
|
||||
)
|
||||
object.__setattr__(self, name, value)
|
||||
|
||||
def override(self, **overrides: Unpack[_ModelRequestOverrides]) -> ModelRequest:
|
||||
"""Replace the request with a new request with the given overrides.
|
||||
|
||||
@@ -446,7 +471,14 @@ class AgentMiddleware(Generic[StateT, ContextT]):
|
||||
|
||||
```python
|
||||
def wrap_tool_call(self, request, handler):
|
||||
request.tool_call["args"]["value"] *= 2
|
||||
modified_call = {
|
||||
**request.tool_call,
|
||||
"args": {
|
||||
**request.tool_call["args"],
|
||||
"value": request.tool_call["args"]["value"] * 2,
|
||||
},
|
||||
}
|
||||
request = request.override(tool_call=modified_call)
|
||||
return handler(request)
|
||||
```
|
||||
|
||||
@@ -1337,7 +1369,7 @@ def dynamic_prompt(
|
||||
handler: Callable[[ModelRequest], Awaitable[ModelResponse]],
|
||||
) -> ModelCallResult:
|
||||
prompt = await func(request) # type: ignore[misc]
|
||||
request.system_prompt = prompt
|
||||
request = request.override(system_prompt=prompt)
|
||||
return await handler(request)
|
||||
|
||||
middleware_name = cast("str", getattr(func, "__name__", "DynamicPromptMiddleware"))
|
||||
@@ -1358,7 +1390,7 @@ def dynamic_prompt(
|
||||
handler: Callable[[ModelRequest], ModelResponse],
|
||||
) -> ModelCallResult:
|
||||
prompt = cast("str", func(request))
|
||||
request.system_prompt = prompt
|
||||
request = request.override(system_prompt=prompt)
|
||||
return handler(request)
|
||||
|
||||
async def async_wrapped_from_sync(
|
||||
@@ -1368,7 +1400,7 @@ def dynamic_prompt(
|
||||
) -> ModelCallResult:
|
||||
# Delegate to sync function
|
||||
prompt = cast("str", func(request))
|
||||
request.system_prompt = prompt
|
||||
request = request.override(system_prompt=prompt)
|
||||
return await handler(request)
|
||||
|
||||
middleware_name = cast("str", getattr(func, "__name__", "DynamicPromptMiddleware"))
|
||||
@@ -1469,7 +1501,7 @@ def wrap_model_call(
|
||||
pass
|
||||
|
||||
# Try fallback model
|
||||
request.model = fallback_model_instance
|
||||
request = request.override(model=fallback_model_instance)
|
||||
return handler(request)
|
||||
```
|
||||
|
||||
@@ -1632,7 +1664,14 @@ def wrap_tool_call(
|
||||
```python
|
||||
@wrap_tool_call
|
||||
def modify_args(request, handler):
|
||||
request.tool_call["args"]["value"] *= 2
|
||||
modified_call = {
|
||||
**request.tool_call,
|
||||
"args": {
|
||||
**request.tool_call["args"],
|
||||
"value": request.tool_call["args"]["value"] * 2,
|
||||
},
|
||||
}
|
||||
request = request.override(tool_call=modified_call)
|
||||
return handler(request)
|
||||
```
|
||||
|
||||
|
||||
@@ -230,9 +230,7 @@ class TestChainModelCallHandlers:
|
||||
test_runtime = {"test": "runtime"}
|
||||
|
||||
# Create request with state and runtime
|
||||
test_request = create_test_request()
|
||||
test_request.state = test_state
|
||||
test_request.runtime = test_runtime
|
||||
test_request = create_test_request(state=test_state, runtime=test_runtime)
|
||||
result = composed(test_request, create_mock_base_handler())
|
||||
|
||||
# Both handlers should see same state and runtime
|
||||
|
||||
@@ -90,8 +90,7 @@ def test_on_model_call_decorator() -> None:
|
||||
|
||||
@wrap_model_call(state_schema=CustomState, tools=[test_tool], name="CustomOnModelCall")
|
||||
def custom_on_model_call(request, handler):
|
||||
request.system_prompt = "Modified"
|
||||
return handler(request)
|
||||
return handler(request.override(system_prompt="Modified"))
|
||||
|
||||
# Verify all options were applied
|
||||
assert isinstance(custom_on_model_call, AgentMiddleware)
|
||||
@@ -277,8 +276,7 @@ def test_async_on_model_call_decorator() -> None:
|
||||
|
||||
@wrap_model_call(state_schema=CustomState, tools=[test_tool], name="AsyncOnModelCall")
|
||||
async def async_on_model_call(request, handler):
|
||||
request.system_prompt = "Modified async"
|
||||
return await handler(request)
|
||||
return await handler(request.override(system_prompt="Modified async"))
|
||||
|
||||
assert isinstance(async_on_model_call, AgentMiddleware)
|
||||
assert async_on_model_call.state_schema == CustomState
|
||||
|
||||
@@ -79,8 +79,8 @@ def test_middleware_can_modify_tools() -> None:
|
||||
handler: Callable[[ModelRequest], AIMessage],
|
||||
) -> AIMessage:
|
||||
# Only allow tool_a and tool_b
|
||||
request.tools = [t for t in request.tools if t.name in ["tool_a", "tool_b"]]
|
||||
return handler(request)
|
||||
filtered_tools = [t for t in request.tools if t.name in ["tool_a", "tool_b"]]
|
||||
return handler(request.override(tools=filtered_tools))
|
||||
|
||||
# Model will try to call tool_a
|
||||
model = FakeToolCallingModel(
|
||||
@@ -123,8 +123,7 @@ def test_unknown_tool_raises_error() -> None:
|
||||
handler: Callable[[ModelRequest], AIMessage],
|
||||
) -> AIMessage:
|
||||
# Add an unknown tool
|
||||
request.tools = request.tools + [unknown_tool]
|
||||
return handler(request)
|
||||
return handler(request.override(tools=request.tools + [unknown_tool]))
|
||||
|
||||
agent = create_agent(
|
||||
model=FakeToolCallingModel(),
|
||||
@@ -163,7 +162,8 @@ def test_middleware_can_add_and_remove_tools() -> None:
|
||||
) -> AIMessage:
|
||||
# Remove admin_tool if not admin
|
||||
if not request.state.get("is_admin", False):
|
||||
request.tools = [t for t in request.tools if t.name != "admin_tool"]
|
||||
filtered_tools = [t for t in request.tools if t.name != "admin_tool"]
|
||||
request = request.override(tools=filtered_tools)
|
||||
return handler(request)
|
||||
|
||||
model = FakeToolCallingModel()
|
||||
@@ -200,7 +200,7 @@ def test_empty_tools_list_is_valid() -> None:
|
||||
handler: Callable[[ModelRequest], AIMessage],
|
||||
) -> AIMessage:
|
||||
# Remove all tools
|
||||
request.tools = []
|
||||
request = request.override(tools=[])
|
||||
return handler(request)
|
||||
|
||||
model = FakeToolCallingModel()
|
||||
@@ -244,7 +244,8 @@ def test_tools_preserved_across_multiple_middleware() -> None:
|
||||
) -> AIMessage:
|
||||
modification_order.append([t.name for t in request.tools])
|
||||
# Remove tool_c
|
||||
request.tools = [t for t in request.tools if t.name != "tool_c"]
|
||||
filtered_tools = [t for t in request.tools if t.name != "tool_c"]
|
||||
request = request.override(tools=filtered_tools)
|
||||
return handler(request)
|
||||
|
||||
class SecondMiddleware(AgentMiddleware):
|
||||
@@ -257,7 +258,8 @@ def test_tools_preserved_across_multiple_middleware() -> None:
|
||||
# Should not see tool_c here
|
||||
assert all(t.name != "tool_c" for t in request.tools)
|
||||
# Remove tool_b
|
||||
request.tools = [t for t in request.tools if t.name != "tool_b"]
|
||||
filtered_tools = [t for t in request.tools if t.name != "tool_b"]
|
||||
request = request.override(tools=filtered_tools)
|
||||
return handler(request)
|
||||
|
||||
agent = create_agent(
|
||||
|
||||
@@ -82,16 +82,23 @@ def test_no_edit_when_below_trigger() -> None:
|
||||
edits=[ClearToolUsesEdit(trigger=50)],
|
||||
)
|
||||
|
||||
modified_request = None
|
||||
|
||||
def mock_handler(req: ModelRequest) -> AIMessage:
|
||||
nonlocal modified_request
|
||||
modified_request = req
|
||||
return AIMessage(content="mock response")
|
||||
|
||||
# Call wrap_model_call which modifies the request
|
||||
# Call wrap_model_call which creates a new request
|
||||
middleware.wrap_model_call(request, mock_handler)
|
||||
|
||||
# The request should have been modified in place
|
||||
# The modified request passed to handler should be the same since no edits applied
|
||||
assert modified_request is not None
|
||||
assert modified_request.messages[0].content == ""
|
||||
assert modified_request.messages[1].content == "12345"
|
||||
# Original request should be unchanged
|
||||
assert request.messages[0].content == ""
|
||||
assert request.messages[1].content == "12345"
|
||||
assert state["messages"] == request.messages
|
||||
|
||||
|
||||
def test_clear_tool_outputs_and_inputs() -> None:
|
||||
@@ -115,14 +122,19 @@ def test_clear_tool_outputs_and_inputs() -> None:
|
||||
)
|
||||
middleware = ContextEditingMiddleware(edits=[edit])
|
||||
|
||||
modified_request = None
|
||||
|
||||
def mock_handler(req: ModelRequest) -> AIMessage:
|
||||
nonlocal modified_request
|
||||
modified_request = req
|
||||
return AIMessage(content="mock response")
|
||||
|
||||
# Call wrap_model_call which modifies the request
|
||||
# Call wrap_model_call which creates a new request with edits
|
||||
middleware.wrap_model_call(request, mock_handler)
|
||||
|
||||
cleared_ai = request.messages[0]
|
||||
cleared_tool = request.messages[1]
|
||||
assert modified_request is not None
|
||||
cleared_ai = modified_request.messages[0]
|
||||
cleared_tool = modified_request.messages[1]
|
||||
|
||||
assert isinstance(cleared_tool, ToolMessage)
|
||||
assert cleared_tool.content == "[cleared output]"
|
||||
@@ -134,7 +146,9 @@ def test_clear_tool_outputs_and_inputs() -> None:
|
||||
assert context_meta is not None
|
||||
assert context_meta["cleared_tool_inputs"] == [tool_call_id]
|
||||
|
||||
assert state["messages"] == request.messages
|
||||
# Original request should be unchanged
|
||||
assert request.messages[0].tool_calls[0]["args"] == {"query": "foo"}
|
||||
assert request.messages[1].content == "x" * 200
|
||||
|
||||
|
||||
def test_respects_keep_last_tool_results() -> None:
|
||||
@@ -167,21 +181,26 @@ def test_respects_keep_last_tool_results() -> None:
|
||||
token_count_method="model",
|
||||
)
|
||||
|
||||
modified_request = None
|
||||
|
||||
def mock_handler(req: ModelRequest) -> AIMessage:
|
||||
nonlocal modified_request
|
||||
modified_request = req
|
||||
return AIMessage(content="mock response")
|
||||
|
||||
# Call wrap_model_call which modifies the request
|
||||
# Call wrap_model_call which creates a new request with edits
|
||||
middleware.wrap_model_call(request, mock_handler)
|
||||
|
||||
assert modified_request is not None
|
||||
cleared_messages = [
|
||||
msg
|
||||
for msg in request.messages
|
||||
for msg in modified_request.messages
|
||||
if isinstance(msg, ToolMessage) and msg.content == "[cleared]"
|
||||
]
|
||||
|
||||
assert len(cleared_messages) == 2
|
||||
assert isinstance(request.messages[-1], ToolMessage)
|
||||
assert request.messages[-1].content != "[cleared]"
|
||||
assert isinstance(modified_request.messages[-1], ToolMessage)
|
||||
assert modified_request.messages[-1].content != "[cleared]"
|
||||
|
||||
|
||||
def test_exclude_tools_prevents_clearing() -> None:
|
||||
@@ -215,14 +234,19 @@ def test_exclude_tools_prevents_clearing() -> None:
|
||||
],
|
||||
)
|
||||
|
||||
modified_request = None
|
||||
|
||||
def mock_handler(req: ModelRequest) -> AIMessage:
|
||||
nonlocal modified_request
|
||||
modified_request = req
|
||||
return AIMessage(content="mock response")
|
||||
|
||||
# Call wrap_model_call which modifies the request
|
||||
# Call wrap_model_call which creates a new request with edits
|
||||
middleware.wrap_model_call(request, mock_handler)
|
||||
|
||||
search_tool = request.messages[1]
|
||||
calc_tool = request.messages[3]
|
||||
assert modified_request is not None
|
||||
search_tool = modified_request.messages[1]
|
||||
calc_tool = modified_request.messages[3]
|
||||
|
||||
assert isinstance(search_tool, ToolMessage)
|
||||
assert search_tool.content == "search-results" * 20
|
||||
@@ -249,16 +273,23 @@ async def test_no_edit_when_below_trigger_async() -> None:
|
||||
edits=[ClearToolUsesEdit(trigger=50)],
|
||||
)
|
||||
|
||||
modified_request = None
|
||||
|
||||
async def mock_handler(req: ModelRequest) -> AIMessage:
|
||||
nonlocal modified_request
|
||||
modified_request = req
|
||||
return AIMessage(content="mock response")
|
||||
|
||||
# Call awrap_model_call which modifies the request
|
||||
# Call awrap_model_call which creates a new request
|
||||
await middleware.awrap_model_call(request, mock_handler)
|
||||
|
||||
# The request should have been modified in place
|
||||
# The modified request passed to handler should be the same since no edits applied
|
||||
assert modified_request is not None
|
||||
assert modified_request.messages[0].content == ""
|
||||
assert modified_request.messages[1].content == "12345"
|
||||
# Original request should be unchanged
|
||||
assert request.messages[0].content == ""
|
||||
assert request.messages[1].content == "12345"
|
||||
assert state["messages"] == request.messages
|
||||
|
||||
|
||||
async def test_clear_tool_outputs_and_inputs_async() -> None:
|
||||
@@ -283,14 +314,19 @@ async def test_clear_tool_outputs_and_inputs_async() -> None:
|
||||
)
|
||||
middleware = ContextEditingMiddleware(edits=[edit])
|
||||
|
||||
modified_request = None
|
||||
|
||||
async def mock_handler(req: ModelRequest) -> AIMessage:
|
||||
nonlocal modified_request
|
||||
modified_request = req
|
||||
return AIMessage(content="mock response")
|
||||
|
||||
# Call awrap_model_call which modifies the request
|
||||
# Call awrap_model_call which creates a new request with edits
|
||||
await middleware.awrap_model_call(request, mock_handler)
|
||||
|
||||
cleared_ai = request.messages[0]
|
||||
cleared_tool = request.messages[1]
|
||||
assert modified_request is not None
|
||||
cleared_ai = modified_request.messages[0]
|
||||
cleared_tool = modified_request.messages[1]
|
||||
|
||||
assert isinstance(cleared_tool, ToolMessage)
|
||||
assert cleared_tool.content == "[cleared output]"
|
||||
@@ -302,7 +338,9 @@ async def test_clear_tool_outputs_and_inputs_async() -> None:
|
||||
assert context_meta is not None
|
||||
assert context_meta["cleared_tool_inputs"] == [tool_call_id]
|
||||
|
||||
assert state["messages"] == request.messages
|
||||
# Original request should be unchanged
|
||||
assert request.messages[0].tool_calls[0]["args"] == {"query": "foo"}
|
||||
assert request.messages[1].content == "x" * 200
|
||||
|
||||
|
||||
async def test_respects_keep_last_tool_results_async() -> None:
|
||||
@@ -336,21 +374,26 @@ async def test_respects_keep_last_tool_results_async() -> None:
|
||||
token_count_method="model",
|
||||
)
|
||||
|
||||
modified_request = None
|
||||
|
||||
async def mock_handler(req: ModelRequest) -> AIMessage:
|
||||
nonlocal modified_request
|
||||
modified_request = req
|
||||
return AIMessage(content="mock response")
|
||||
|
||||
# Call awrap_model_call which modifies the request
|
||||
# Call awrap_model_call which creates a new request with edits
|
||||
await middleware.awrap_model_call(request, mock_handler)
|
||||
|
||||
assert modified_request is not None
|
||||
cleared_messages = [
|
||||
msg
|
||||
for msg in request.messages
|
||||
for msg in modified_request.messages
|
||||
if isinstance(msg, ToolMessage) and msg.content == "[cleared]"
|
||||
]
|
||||
|
||||
assert len(cleared_messages) == 2
|
||||
assert isinstance(request.messages[-1], ToolMessage)
|
||||
assert request.messages[-1].content != "[cleared]"
|
||||
assert isinstance(modified_request.messages[-1], ToolMessage)
|
||||
assert modified_request.messages[-1].content != "[cleared]"
|
||||
|
||||
|
||||
async def test_exclude_tools_prevents_clearing_async() -> None:
|
||||
@@ -385,14 +428,19 @@ async def test_exclude_tools_prevents_clearing_async() -> None:
|
||||
],
|
||||
)
|
||||
|
||||
modified_request = None
|
||||
|
||||
async def mock_handler(req: ModelRequest) -> AIMessage:
|
||||
nonlocal modified_request
|
||||
modified_request = req
|
||||
return AIMessage(content="mock response")
|
||||
|
||||
# Call awrap_model_call which modifies the request
|
||||
# Call awrap_model_call which creates a new request with edits
|
||||
await middleware.awrap_model_call(request, mock_handler)
|
||||
|
||||
search_tool = request.messages[1]
|
||||
calc_tool = request.messages[3]
|
||||
assert modified_request is not None
|
||||
search_tool = modified_request.messages[1]
|
||||
calc_tool = modified_request.messages[3]
|
||||
|
||||
assert isinstance(search_tool, ToolMessage)
|
||||
assert search_tool.content == "search-results" * 20
|
||||
|
||||
@@ -2,6 +2,7 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import warnings
|
||||
from typing import cast
|
||||
|
||||
import pytest
|
||||
@@ -45,7 +46,7 @@ def test_primary_model_succeeds() -> None:
|
||||
|
||||
middleware = ModelFallbackMiddleware(fallback_model)
|
||||
request = _make_request()
|
||||
request.model = primary_model
|
||||
request = request.override(model=primary_model)
|
||||
|
||||
def mock_handler(req: ModelRequest) -> ModelResponse:
|
||||
# Simulate successful model call
|
||||
@@ -70,7 +71,7 @@ def test_fallback_on_primary_failure() -> None:
|
||||
|
||||
middleware = ModelFallbackMiddleware(fallback_model)
|
||||
request = _make_request()
|
||||
request.model = primary_model
|
||||
request = request.override(model=primary_model)
|
||||
|
||||
def mock_handler(req: ModelRequest) -> ModelResponse:
|
||||
result = req.model.invoke([])
|
||||
@@ -95,7 +96,7 @@ def test_multiple_fallbacks() -> None:
|
||||
|
||||
middleware = ModelFallbackMiddleware(fallback1, fallback2)
|
||||
request = _make_request()
|
||||
request.model = primary_model
|
||||
request = request.override(model=primary_model)
|
||||
|
||||
def mock_handler(req: ModelRequest) -> ModelResponse:
|
||||
result = req.model.invoke([])
|
||||
@@ -119,7 +120,7 @@ def test_all_models_fail() -> None:
|
||||
|
||||
middleware = ModelFallbackMiddleware(fallback_model)
|
||||
request = _make_request()
|
||||
request.model = primary_model
|
||||
request = request.override(model=primary_model)
|
||||
|
||||
def mock_handler(req: ModelRequest) -> ModelResponse:
|
||||
result = req.model.invoke([])
|
||||
@@ -136,7 +137,7 @@ async def test_primary_model_succeeds_async() -> None:
|
||||
|
||||
middleware = ModelFallbackMiddleware(fallback_model)
|
||||
request = _make_request()
|
||||
request.model = primary_model
|
||||
request = request.override(model=primary_model)
|
||||
|
||||
async def mock_handler(req: ModelRequest) -> ModelResponse:
|
||||
# Simulate successful async model call
|
||||
@@ -161,7 +162,7 @@ async def test_fallback_on_primary_failure_async() -> None:
|
||||
|
||||
middleware = ModelFallbackMiddleware(fallback_model)
|
||||
request = _make_request()
|
||||
request.model = primary_model
|
||||
request = request.override(model=primary_model)
|
||||
|
||||
async def mock_handler(req: ModelRequest) -> ModelResponse:
|
||||
result = await req.model.ainvoke([])
|
||||
@@ -186,7 +187,7 @@ async def test_multiple_fallbacks_async() -> None:
|
||||
|
||||
middleware = ModelFallbackMiddleware(fallback1, fallback2)
|
||||
request = _make_request()
|
||||
request.model = primary_model
|
||||
request = request.override(model=primary_model)
|
||||
|
||||
async def mock_handler(req: ModelRequest) -> ModelResponse:
|
||||
result = await req.model.ainvoke([])
|
||||
@@ -210,7 +211,7 @@ async def test_all_models_fail_async() -> None:
|
||||
|
||||
middleware = ModelFallbackMiddleware(fallback_model)
|
||||
request = _make_request()
|
||||
request.model = primary_model
|
||||
request = request.override(model=primary_model)
|
||||
|
||||
async def mock_handler(req: ModelRequest) -> ModelResponse:
|
||||
result = await req.model.ainvoke([])
|
||||
@@ -305,3 +306,46 @@ def test_model_fallback_middleware_initialization() -> None:
|
||||
# Test with multiple fallback models
|
||||
middleware = ModelFallbackMiddleware(FakeToolCallingModel(), FakeToolCallingModel())
|
||||
assert len(middleware.models) == 2
|
||||
|
||||
|
||||
def test_model_request_is_frozen() -> None:
|
||||
"""Test that ModelRequest raises deprecation warning on direct attribute assignment."""
|
||||
request = _make_request()
|
||||
new_model = GenericFakeChatModel(messages=iter([AIMessage(content="new model")]))
|
||||
|
||||
# Direct attribute assignment should raise DeprecationWarning but still work
|
||||
with pytest.warns(
|
||||
DeprecationWarning, match="Direct attribute assignment to ModelRequest.model is deprecated"
|
||||
):
|
||||
request.model = new_model # type: ignore[misc]
|
||||
|
||||
# Verify the assignment actually worked
|
||||
assert request.model == new_model
|
||||
|
||||
with pytest.warns(
|
||||
DeprecationWarning,
|
||||
match="Direct attribute assignment to ModelRequest.system_prompt is deprecated",
|
||||
):
|
||||
request.system_prompt = "new prompt" # type: ignore[misc]
|
||||
|
||||
assert request.system_prompt == "new prompt"
|
||||
|
||||
with pytest.warns(
|
||||
DeprecationWarning,
|
||||
match="Direct attribute assignment to ModelRequest.messages is deprecated",
|
||||
):
|
||||
request.messages = [] # type: ignore[misc]
|
||||
|
||||
assert request.messages == []
|
||||
|
||||
# Using override method should work without warnings
|
||||
request2 = _make_request()
|
||||
with warnings.catch_warnings():
|
||||
warnings.simplefilter("error") # Turn warnings into errors
|
||||
new_request = request2.override(model=new_model, system_prompt="override prompt")
|
||||
|
||||
assert new_request.model == new_model
|
||||
assert new_request.system_prompt == "override prompt"
|
||||
# Original request should be unchanged
|
||||
assert request2.model != new_model
|
||||
assert request2.system_prompt != "override prompt"
|
||||
|
||||
@@ -83,14 +83,21 @@ def test_adds_system_prompt_when_none_exists() -> None:
|
||||
middleware = TodoListMiddleware()
|
||||
request = _make_request(system_prompt=None)
|
||||
|
||||
captured_request = None
|
||||
|
||||
def mock_handler(req: ModelRequest) -> ModelResponse:
|
||||
nonlocal captured_request
|
||||
captured_request = req
|
||||
return ModelResponse(result=[AIMessage(content="response")])
|
||||
|
||||
middleware.wrap_model_call(request, mock_handler)
|
||||
|
||||
# System prompt should be set
|
||||
assert request.system_prompt is not None
|
||||
assert "write_todos" in request.system_prompt
|
||||
# System prompt should be set in the modified request passed to handler
|
||||
assert captured_request is not None
|
||||
assert captured_request.system_prompt is not None
|
||||
assert "write_todos" in captured_request.system_prompt
|
||||
# Original request should be unchanged
|
||||
assert request.system_prompt is None
|
||||
|
||||
|
||||
def test_appends_to_existing_system_prompt() -> None:
|
||||
@@ -99,16 +106,23 @@ def test_appends_to_existing_system_prompt() -> None:
|
||||
middleware = TodoListMiddleware()
|
||||
request = _make_request(system_prompt=existing_prompt)
|
||||
|
||||
captured_request = None
|
||||
|
||||
def mock_handler(req: ModelRequest) -> ModelResponse:
|
||||
nonlocal captured_request
|
||||
captured_request = req
|
||||
return ModelResponse(result=[AIMessage(content="response")])
|
||||
|
||||
middleware.wrap_model_call(request, mock_handler)
|
||||
|
||||
# System prompt should contain both
|
||||
assert request.system_prompt is not None
|
||||
assert existing_prompt in request.system_prompt
|
||||
assert "write_todos" in request.system_prompt
|
||||
assert request.system_prompt.startswith(existing_prompt)
|
||||
# System prompt should contain both in the modified request passed to handler
|
||||
assert captured_request is not None
|
||||
assert captured_request.system_prompt is not None
|
||||
assert existing_prompt in captured_request.system_prompt
|
||||
assert "write_todos" in captured_request.system_prompt
|
||||
assert captured_request.system_prompt.startswith(existing_prompt)
|
||||
# Original request should be unchanged
|
||||
assert request.system_prompt == existing_prompt
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
@@ -137,13 +151,20 @@ def test_todo_middleware_on_model_call(original_prompt, expected_prompt_prefix)
|
||||
model_settings={},
|
||||
)
|
||||
|
||||
captured_request = None
|
||||
|
||||
def mock_handler(req: ModelRequest) -> AIMessage:
|
||||
nonlocal captured_request
|
||||
captured_request = req
|
||||
return AIMessage(content="mock response")
|
||||
|
||||
# Call wrap_model_call to trigger the middleware logic
|
||||
middleware.wrap_model_call(request, mock_handler)
|
||||
# Check that the request was modified in place
|
||||
assert request.system_prompt.startswith(expected_prompt_prefix)
|
||||
# Check that the modified request passed to handler has the expected prompt
|
||||
assert captured_request is not None
|
||||
assert captured_request.system_prompt.startswith(expected_prompt_prefix)
|
||||
# Original request should be unchanged
|
||||
assert request.system_prompt == original_prompt
|
||||
|
||||
|
||||
def test_custom_system_prompt() -> None:
|
||||
@@ -152,13 +173,20 @@ def test_custom_system_prompt() -> None:
|
||||
middleware = TodoListMiddleware(system_prompt=custom_prompt)
|
||||
request = _make_request(system_prompt=None)
|
||||
|
||||
captured_request = None
|
||||
|
||||
def mock_handler(req: ModelRequest) -> ModelResponse:
|
||||
nonlocal captured_request
|
||||
captured_request = req
|
||||
return ModelResponse(result=[AIMessage(content="response")])
|
||||
|
||||
middleware.wrap_model_call(request, mock_handler)
|
||||
|
||||
# Should use custom prompt
|
||||
assert request.system_prompt == custom_prompt
|
||||
# Should use custom prompt in the modified request passed to handler
|
||||
assert captured_request is not None
|
||||
assert captured_request.system_prompt == custom_prompt
|
||||
# Original request should be unchanged
|
||||
assert request.system_prompt is None
|
||||
|
||||
|
||||
def test_todo_middleware_custom_system_prompt() -> None:
|
||||
@@ -181,13 +209,20 @@ def test_todo_middleware_custom_system_prompt() -> None:
|
||||
runtime=cast(Runtime, object()),
|
||||
)
|
||||
|
||||
captured_request = None
|
||||
|
||||
def mock_handler(req: ModelRequest) -> AIMessage:
|
||||
nonlocal captured_request
|
||||
captured_request = req
|
||||
return AIMessage(content="mock response")
|
||||
|
||||
# Call wrap_model_call to trigger the middleware logic
|
||||
middleware.wrap_model_call(request, mock_handler)
|
||||
# Check that the request was modified in place
|
||||
assert request.system_prompt == f"Original prompt\n\n{custom_system_prompt}"
|
||||
# Check that the modified request passed to handler has the expected prompt
|
||||
assert captured_request is not None
|
||||
assert captured_request.system_prompt == f"Original prompt\n\n{custom_system_prompt}"
|
||||
# Original request should be unchanged
|
||||
assert request.system_prompt == "Original prompt"
|
||||
|
||||
|
||||
def test_custom_tool_description() -> None:
|
||||
@@ -235,13 +270,20 @@ def test_todo_middleware_custom_system_prompt_and_tool_description() -> None:
|
||||
model_settings={},
|
||||
)
|
||||
|
||||
captured_request = None
|
||||
|
||||
def mock_handler(req: ModelRequest) -> AIMessage:
|
||||
nonlocal captured_request
|
||||
captured_request = req
|
||||
return AIMessage(content="mock response")
|
||||
|
||||
# Call wrap_model_call to trigger the middleware logic
|
||||
middleware.wrap_model_call(request, mock_handler)
|
||||
# Check that the request was modified in place
|
||||
assert request.system_prompt == custom_system_prompt
|
||||
# Check that the modified request passed to handler has the expected prompt
|
||||
assert captured_request is not None
|
||||
assert captured_request.system_prompt == custom_system_prompt
|
||||
# Original request should be unchanged
|
||||
assert request.system_prompt is None
|
||||
|
||||
# Verify tool description
|
||||
assert len(middleware.tools) == 1
|
||||
@@ -390,14 +432,21 @@ async def test_adds_system_prompt_when_none_exists_async() -> None:
|
||||
middleware = TodoListMiddleware()
|
||||
request = _make_request(system_prompt=None)
|
||||
|
||||
captured_request = None
|
||||
|
||||
async def mock_handler(req: ModelRequest) -> ModelResponse:
|
||||
nonlocal captured_request
|
||||
captured_request = req
|
||||
return ModelResponse(result=[AIMessage(content="response")])
|
||||
|
||||
await middleware.awrap_model_call(request, mock_handler)
|
||||
|
||||
# System prompt should be set
|
||||
assert request.system_prompt is not None
|
||||
assert "write_todos" in request.system_prompt
|
||||
# System prompt should be set in the modified request passed to handler
|
||||
assert captured_request is not None
|
||||
assert captured_request.system_prompt is not None
|
||||
assert "write_todos" in captured_request.system_prompt
|
||||
# Original request should be unchanged
|
||||
assert request.system_prompt is None
|
||||
|
||||
|
||||
async def test_appends_to_existing_system_prompt_async() -> None:
|
||||
@@ -406,16 +455,23 @@ async def test_appends_to_existing_system_prompt_async() -> None:
|
||||
middleware = TodoListMiddleware()
|
||||
request = _make_request(system_prompt=existing_prompt)
|
||||
|
||||
captured_request = None
|
||||
|
||||
async def mock_handler(req: ModelRequest) -> ModelResponse:
|
||||
nonlocal captured_request
|
||||
captured_request = req
|
||||
return ModelResponse(result=[AIMessage(content="response")])
|
||||
|
||||
await middleware.awrap_model_call(request, mock_handler)
|
||||
|
||||
# System prompt should contain both
|
||||
assert request.system_prompt is not None
|
||||
assert existing_prompt in request.system_prompt
|
||||
assert "write_todos" in request.system_prompt
|
||||
assert request.system_prompt.startswith(existing_prompt)
|
||||
# System prompt should contain both in the modified request passed to handler
|
||||
assert captured_request is not None
|
||||
assert captured_request.system_prompt is not None
|
||||
assert existing_prompt in captured_request.system_prompt
|
||||
assert "write_todos" in captured_request.system_prompt
|
||||
assert captured_request.system_prompt.startswith(existing_prompt)
|
||||
# Original request should be unchanged
|
||||
assert request.system_prompt == existing_prompt
|
||||
|
||||
|
||||
async def test_custom_system_prompt_async() -> None:
|
||||
@@ -424,13 +480,20 @@ async def test_custom_system_prompt_async() -> None:
|
||||
middleware = TodoListMiddleware(system_prompt=custom_prompt)
|
||||
request = _make_request(system_prompt=None)
|
||||
|
||||
captured_request = None
|
||||
|
||||
async def mock_handler(req: ModelRequest) -> ModelResponse:
|
||||
nonlocal captured_request
|
||||
captured_request = req
|
||||
return ModelResponse(result=[AIMessage(content="response")])
|
||||
|
||||
await middleware.awrap_model_call(request, mock_handler)
|
||||
|
||||
# Should use custom prompt
|
||||
assert request.system_prompt == custom_prompt
|
||||
# Should use custom prompt in the modified request passed to handler
|
||||
assert captured_request is not None
|
||||
assert captured_request.system_prompt == custom_prompt
|
||||
# Original request should be unchanged
|
||||
assert request.system_prompt is None
|
||||
|
||||
|
||||
async def test_handler_called_with_modified_request_async() -> None:
|
||||
|
||||
@@ -784,8 +784,7 @@ class TestDynamicModelWithResponseFormat:
|
||||
handler: Callable[[ModelRequest], CoreAIMessage],
|
||||
) -> CoreAIMessage:
|
||||
# Replace the model with our custom test model
|
||||
request.model = model
|
||||
return handler(request)
|
||||
return handler(request.override(model=model))
|
||||
|
||||
# Track which model is checked for provider strategy support
|
||||
calls = []
|
||||
|
||||
@@ -94,14 +94,6 @@ class AnthropicPromptCachingMiddleware(AgentMiddleware):
|
||||
)
|
||||
return messages_count >= self.min_messages_to_cache
|
||||
|
||||
def _apply_cache_control(self, request: ModelRequest) -> None:
|
||||
"""Apply cache control settings to the request.
|
||||
|
||||
Args:
|
||||
request: The model request to modify.
|
||||
"""
|
||||
request.model_settings["cache_control"] = {"type": self.type, "ttl": self.ttl}
|
||||
|
||||
def wrap_model_call(
|
||||
self,
|
||||
request: ModelRequest,
|
||||
@@ -119,8 +111,12 @@ class AnthropicPromptCachingMiddleware(AgentMiddleware):
|
||||
if not self._should_apply_caching(request):
|
||||
return handler(request)
|
||||
|
||||
self._apply_cache_control(request)
|
||||
return handler(request)
|
||||
model_settings = request.model_settings
|
||||
new_model_settings = {
|
||||
**model_settings,
|
||||
"cache_control": {"type": self.type, "ttl": self.ttl},
|
||||
}
|
||||
return handler(request.override(model_settings=new_model_settings))
|
||||
|
||||
async def awrap_model_call(
|
||||
self,
|
||||
@@ -139,5 +135,9 @@ class AnthropicPromptCachingMiddleware(AgentMiddleware):
|
||||
if not self._should_apply_caching(request):
|
||||
return await handler(request)
|
||||
|
||||
self._apply_cache_control(request)
|
||||
return await handler(request)
|
||||
model_settings = request.model_settings
|
||||
new_model_settings = {
|
||||
**model_settings,
|
||||
"cache_control": {"type": self.type, "ttl": self.ttl},
|
||||
}
|
||||
return await handler(request.override(model_settings=new_model_settings))
|
||||
|
||||
@@ -82,12 +82,17 @@ def test_anthropic_prompt_caching_middleware_initialization() -> None:
|
||||
model_settings={},
|
||||
)
|
||||
|
||||
modified_request: ModelRequest | None = None
|
||||
|
||||
def mock_handler(req: ModelRequest) -> ModelResponse:
|
||||
nonlocal modified_request
|
||||
modified_request = req
|
||||
return ModelResponse(result=[AIMessage(content="mock response")])
|
||||
|
||||
middleware.wrap_model_call(fake_request, mock_handler)
|
||||
# Check that model_settings were passed through via the request
|
||||
assert fake_request.model_settings == {
|
||||
assert modified_request is not None
|
||||
assert modified_request.model_settings == {
|
||||
"cache_control": {"type": "ephemeral", "ttl": "5m"}
|
||||
}
|
||||
|
||||
@@ -162,13 +167,18 @@ async def test_anthropic_prompt_caching_middleware_async() -> None:
|
||||
model_settings={},
|
||||
)
|
||||
|
||||
modified_request: ModelRequest | None = None
|
||||
|
||||
async def mock_handler(req: ModelRequest) -> ModelResponse:
|
||||
nonlocal modified_request
|
||||
modified_request = req
|
||||
return ModelResponse(result=[AIMessage(content="mock response")])
|
||||
|
||||
result = await middleware.awrap_model_call(fake_request, mock_handler)
|
||||
assert isinstance(result, ModelResponse)
|
||||
# Check that model_settings were passed through via the request
|
||||
assert fake_request.model_settings == {
|
||||
assert modified_request is not None
|
||||
assert modified_request.model_settings == {
|
||||
"cache_control": {"type": "ephemeral", "ttl": "1h"}
|
||||
}
|
||||
|
||||
@@ -237,13 +247,18 @@ async def test_anthropic_prompt_caching_middleware_async_min_messages() -> None:
|
||||
model_settings={},
|
||||
)
|
||||
|
||||
modified_request: ModelRequest | None = None
|
||||
|
||||
async def mock_handler(req: ModelRequest) -> ModelResponse:
|
||||
nonlocal modified_request
|
||||
modified_request = req
|
||||
return ModelResponse(result=[AIMessage(content="mock response")])
|
||||
|
||||
result = await middleware.awrap_model_call(fake_request, mock_handler)
|
||||
assert isinstance(result, ModelResponse)
|
||||
# Cache control should NOT be added when message count is below minimum
|
||||
assert fake_request.model_settings == {}
|
||||
assert modified_request is not None
|
||||
assert modified_request.model_settings == {}
|
||||
|
||||
|
||||
async def test_anthropic_prompt_caching_middleware_async_with_system_prompt() -> None:
|
||||
@@ -268,13 +283,18 @@ async def test_anthropic_prompt_caching_middleware_async_with_system_prompt() ->
|
||||
model_settings={},
|
||||
)
|
||||
|
||||
modified_request: ModelRequest | None = None
|
||||
|
||||
async def mock_handler(req: ModelRequest) -> ModelResponse:
|
||||
nonlocal modified_request
|
||||
modified_request = req
|
||||
return ModelResponse(result=[AIMessage(content="mock response")])
|
||||
|
||||
result = await middleware.awrap_model_call(fake_request, mock_handler)
|
||||
assert isinstance(result, ModelResponse)
|
||||
# Cache control should be added when system prompt pushes count to minimum
|
||||
assert fake_request.model_settings == {
|
||||
assert modified_request is not None
|
||||
assert modified_request.model_settings == {
|
||||
"cache_control": {"type": "ephemeral", "ttl": "1h"}
|
||||
}
|
||||
|
||||
@@ -300,12 +320,17 @@ async def test_anthropic_prompt_caching_middleware_async_default_values() -> Non
|
||||
model_settings={},
|
||||
)
|
||||
|
||||
modified_request: ModelRequest | None = None
|
||||
|
||||
async def mock_handler(req: ModelRequest) -> ModelResponse:
|
||||
nonlocal modified_request
|
||||
modified_request = req
|
||||
return ModelResponse(result=[AIMessage(content="mock response")])
|
||||
|
||||
result = await middleware.awrap_model_call(fake_request, mock_handler)
|
||||
assert isinstance(result, ModelResponse)
|
||||
# Check that model_settings were added with default values
|
||||
assert fake_request.model_settings == {
|
||||
assert modified_request is not None
|
||||
assert modified_request.model_settings == {
|
||||
"cache_control": {"type": "ephemeral", "ttl": "5m"}
|
||||
}
|
||||
|
||||
1026
libs/partners/anthropic/uv.lock
generated
1026
libs/partners/anthropic/uv.lock
generated
File diff suppressed because it is too large
Load Diff
Reference in New Issue
Block a user