Compare commits


2 Commits

Author          SHA1        Message  Date
Eugene Yurtsev  c21b43fb4e  x        2025-10-09 17:01:09 -04:00
Eugene Yurtsev  05eed19605  x        2025-10-09 16:44:57 -04:00
17 changed files with 843 additions and 548 deletions

View File

@@ -32,7 +32,9 @@ from langchain.agents.middleware.types import (
AgentMiddleware,
AgentState,
JumpTo,
ModelCall,
ModelRequest,
ModelResponse,
OmitFromSchema,
PublicAgentState,
)
@@ -87,14 +89,14 @@ class _InternalModelResponse:
def _chain_model_call_handlers(
handlers: Sequence[
Callable[
[ModelRequest, Callable[[ModelRequest], AIMessage]],
AIMessage,
[ModelRequest, Callable[[ModelCall], ModelResponse]],
ModelResponse,
]
],
) -> (
Callable[
[ModelRequest, Callable[[ModelRequest], AIMessage]],
AIMessage,
[ModelRequest, Callable[[ModelCall], ModelResponse]],
ModelResponse,
]
| None
):
@@ -141,26 +143,26 @@ def _chain_model_call_handlers(
def compose_two(
outer: Callable[
[ModelRequest, Callable[[ModelRequest], AIMessage]],
AIMessage,
[ModelRequest, Callable[[ModelCall], ModelResponse]],
ModelResponse,
],
inner: Callable[
[ModelRequest, Callable[[ModelRequest], AIMessage]],
AIMessage,
[ModelRequest, Callable[[ModelCall], ModelResponse]],
ModelResponse,
],
) -> Callable[
[ModelRequest, Callable[[ModelRequest], AIMessage]],
AIMessage,
[ModelRequest, Callable[[ModelCall], ModelResponse]],
ModelResponse,
]:
"""Compose two handlers where outer wraps inner."""
def composed(
request: ModelRequest,
handler: Callable[[ModelRequest], AIMessage],
) -> AIMessage:
handler: Callable[[ModelCall], ModelResponse],
) -> ModelResponse:
# Create a wrapper that calls inner with the base handler
def inner_handler(req: ModelRequest) -> AIMessage:
return inner(req, handler)
def inner_handler(_model_call: ModelCall) -> ModelResponse:
return inner(request, handler)
# Call outer with the wrapped inner as its handler
return outer(request, inner_handler)
@@ -178,14 +180,14 @@ def _chain_model_call_handlers(
def _chain_async_model_call_handlers(
handlers: Sequence[
Callable[
[ModelRequest, Callable[[ModelRequest], Awaitable[AIMessage]]],
Awaitable[AIMessage],
[ModelRequest, Callable[[ModelCall], Awaitable[ModelResponse]]],
Awaitable[ModelResponse],
]
],
) -> (
Callable[
[ModelRequest, Callable[[ModelRequest], Awaitable[AIMessage]]],
Awaitable[AIMessage],
[ModelRequest, Callable[[ModelCall], Awaitable[ModelResponse]]],
Awaitable[ModelResponse],
]
| None
):
@@ -205,26 +207,26 @@ def _chain_async_model_call_handlers(
def compose_two(
outer: Callable[
[ModelRequest, Callable[[ModelRequest], Awaitable[AIMessage]]],
Awaitable[AIMessage],
[ModelRequest, Callable[[ModelCall], Awaitable[ModelResponse]]],
Awaitable[ModelResponse],
],
inner: Callable[
[ModelRequest, Callable[[ModelRequest], Awaitable[AIMessage]]],
Awaitable[AIMessage],
[ModelRequest, Callable[[ModelCall], Awaitable[ModelResponse]]],
Awaitable[ModelResponse],
],
) -> Callable[
[ModelRequest, Callable[[ModelRequest], Awaitable[AIMessage]]],
Awaitable[AIMessage],
[ModelRequest, Callable[[ModelCall], Awaitable[ModelResponse]]],
Awaitable[ModelResponse],
]:
"""Compose two async handlers where outer wraps inner."""
async def composed(
request: ModelRequest,
handler: Callable[[ModelRequest], Awaitable[AIMessage]],
) -> AIMessage:
handler: Callable[[ModelCall], Awaitable[ModelResponse]],
) -> ModelResponse:
# Create a wrapper that calls inner with the base handler
async def inner_handler(req: ModelRequest) -> AIMessage:
return await inner(req, handler)
async def inner_handler(_model_call: ModelCall) -> ModelResponse:
return await inner(request, handler)
# Call outer with the wrapped inner as its handler
return await outer(request, inner_handler)
@@ -744,13 +746,13 @@ def create_agent( # noqa: PLR0915
return {"messages": [output]}
def _get_bound_model(request: ModelRequest) -> tuple[Runnable, ResponseFormat | None]:
def _get_bound_model(model_call: ModelCall) -> tuple[Runnable, ResponseFormat | None]:
"""Get the model with appropriate tool bindings.
Performs auto-detection of strategy if needed based on model capabilities.
Args:
request: The model request containing model, tools, and response format.
model_call: The model call containing model, tools, and response format.
Returns:
Tuple of (bound_model, effective_response_format) where ``effective_response_format``
@@ -765,7 +767,7 @@ def create_agent( # noqa: PLR0915
# Check if any requested tools are unknown CLIENT-SIDE tools
unknown_tool_names = []
for t in request.tools:
for t in model_call.tools:
# Only validate BaseTool instances (skip built-in dict tools)
if isinstance(t, dict):
continue
@@ -782,7 +784,7 @@ def create_agent( # noqa: PLR0915
"the 'tools' parameter\n"
"2. If using custom middleware with tools, ensure "
"they're registered via middleware.tools attribute\n"
"3. Verify that tool names in ModelRequest.tools match "
"3. Verify that tool names in ModelCall.tools match "
"the actual tool.name values\n"
"Note: Built-in provider tools (dict format) can be added dynamically."
)
@@ -790,22 +792,24 @@ def create_agent( # noqa: PLR0915
# Determine effective response format (auto-detect if needed)
effective_response_format: ResponseFormat | None
if isinstance(request.response_format, AutoStrategy):
if isinstance(model_call.response_format, AutoStrategy):
# User provided raw schema via AutoStrategy - auto-detect best strategy based on model
if _supports_provider_strategy(request.model):
if _supports_provider_strategy(model_call.model):
# Model supports provider strategy - use it
effective_response_format = ProviderStrategy(schema=request.response_format.schema)
effective_response_format = ProviderStrategy(
schema=model_call.response_format.schema
)
else:
# Model doesn't support provider strategy - use ToolStrategy
effective_response_format = ToolStrategy(schema=request.response_format.schema)
effective_response_format = ToolStrategy(schema=model_call.response_format.schema)
else:
# User explicitly specified a strategy - preserve it
effective_response_format = request.response_format
effective_response_format = model_call.response_format
# Build final tools list including structured output tools
# request.tools now only contains BaseTool instances (converted from callables)
# model_call.tools now only contains BaseTool instances (converted from callables)
# and dicts (built-ins)
final_tools = list(request.tools)
final_tools = list(model_call.tools)
if isinstance(effective_response_format, ToolStrategy):
# Add structured output tools to final tools list
structured_tools = [info.tool for info in structured_output_tools.values()]
@@ -816,8 +820,8 @@ def create_agent( # noqa: PLR0915
# Use provider-specific structured output
kwargs = effective_response_format.to_model_kwargs()
return (
request.model.bind_tools(
final_tools, strict=True, **kwargs, **request.model_settings
model_call.model.bind_tools(
final_tools, strict=True, **kwargs, **model_call.model_settings
),
effective_response_format,
)
@@ -839,10 +843,10 @@ def create_agent( # noqa: PLR0915
raise ValueError(msg)
# Force tool use if we have structured output tools
tool_choice = "any" if structured_output_tools else request.tool_choice
tool_choice = "any" if structured_output_tools else model_call.tool_choice
return (
request.model.bind_tools(
final_tools, tool_choice=tool_choice, **request.model_settings
model_call.model.bind_tools(
final_tools, tool_choice=tool_choice, **model_call.model_settings
),
effective_response_format,
)
@@ -850,145 +854,316 @@ def create_agent( # noqa: PLR0915
# No structured output - standard model binding
if final_tools:
return (
request.model.bind_tools(
final_tools, tool_choice=request.tool_choice, **request.model_settings
model_call.model.bind_tools(
final_tools, tool_choice=model_call.tool_choice, **model_call.model_settings
),
None,
)
return request.model.bind(**request.model_settings), None
return model_call.model.bind(**model_call.model_settings), None
def _execute_model_sync(request: ModelRequest) -> _InternalModelResponse:
"""Execute model and return result or exception.
def _execute_model_sync(model_call: ModelCall) -> ModelResponse:
"""Execute model and return ModelResponse with messages and structured output.
This is the core model execution logic wrapped by wrap_model_call handlers.
"""
try:
# Get the bound model (with auto-detection if needed)
model_, effective_response_format = _get_bound_model(request)
messages = request.messages
if request.system_prompt:
messages = [SystemMessage(request.system_prompt), *messages]
Handles model invocation, auto-detection of response format, and structured
output processing.
output = model_.invoke(messages)
return _InternalModelResponse(
result=output,
exception=None,
effective_response_format=effective_response_format,
)
except Exception as error: # noqa: BLE001
# Catch all exceptions from model invocation
return _InternalModelResponse(
result=None,
exception=error,
effective_response_format=None,
)
Args:
model_call: The model call parameters.
Returns:
ModelResponse with result (list of messages) and structured_response (if applicable).
Raises:
Exception: Any exception from model invocation or structured output processing.
"""
# Get the bound model (with auto-detection if needed)
model_, effective_response_format = _get_bound_model(model_call)
messages = model_call.messages
if model_call.system_prompt:
messages = [SystemMessage(model_call.system_prompt), *messages]
output: AIMessage = model_.invoke(messages)
# Handle structured output with provider strategy
if isinstance(effective_response_format, ProviderStrategy):
if not output.tool_calls:
provider_strategy_binding = ProviderStrategyBinding.from_schema_spec(
effective_response_format.schema_spec
)
structured_response = provider_strategy_binding.parse(output)
return ModelResponse(result=[output], structured_response=structured_response)
return ModelResponse(result=[output])
# Handle structured output with tool strategy
if (
isinstance(effective_response_format, ToolStrategy)
and isinstance(output, AIMessage)
and output.tool_calls
):
structured_tool_calls = [
tc for tc in output.tool_calls if tc["name"] in structured_output_tools
]
if structured_tool_calls:
if len(structured_tool_calls) > 1:
# Handle multiple structured outputs error
tool_names = [tc["name"] for tc in structured_tool_calls]
multiple_outputs_error = MultipleStructuredOutputsError(tool_names)
should_retry, error_message = _handle_structured_output_error(
multiple_outputs_error, effective_response_format
)
if not should_retry:
raise multiple_outputs_error
# Add error messages and retry
tool_messages = [
ToolMessage(
content=error_message,
tool_call_id=tc["id"],
name=tc["name"],
)
for tc in structured_tool_calls
]
return ModelResponse(result=[output, *tool_messages])
# Handle single structured output
tool_call = structured_tool_calls[0]
try:
structured_tool_binding = structured_output_tools[tool_call["name"]]
structured_response = structured_tool_binding.parse(tool_call["args"])
tool_message_content = (
effective_response_format.tool_message_content
if effective_response_format.tool_message_content
else f"Returning structured response: {structured_response}"
)
return ModelResponse(
result=[
output,
ToolMessage(
content=tool_message_content,
tool_call_id=tool_call["id"],
name=tool_call["name"],
),
],
structured_response=structured_response,
)
except Exception as exc: # noqa: BLE001
validation_error = StructuredOutputValidationError(tool_call["name"], exc)
should_retry, error_message = _handle_structured_output_error(
validation_error, effective_response_format
)
if not should_retry:
raise validation_error
return ModelResponse(
result=[
output,
ToolMessage(
content=error_message,
tool_call_id=tool_call["id"],
name=tool_call["name"],
),
]
)
return ModelResponse(result=[output])
def model_node(state: AgentState, runtime: Runtime[ContextT]) -> dict[str, Any]:
"""Sync model request handler with sequential middleware processing."""
request = ModelRequest(
# Create ModelCall with invocation parameters
model_call = ModelCall(
model=model,
tools=default_tools,
system_prompt=system_prompt,
response_format=initial_response_format,
messages=state["messages"],
tool_choice=None,
)
# Create ModelRequest with model_call + state + runtime
request = ModelRequest(
model_call=model_call,
state=state,
runtime=runtime,
)
# Execute with or without handler
effective_response_format: Any = None
# Define base handler that executes the model
def base_handler(req: ModelRequest) -> AIMessage:
nonlocal effective_response_format
internal_response = _execute_model_sync(req)
if internal_response.exception is not None:
raise internal_response.exception
if internal_response.result is None:
msg = "Model execution succeeded but returned no result"
raise RuntimeError(msg)
effective_response_format = internal_response.effective_response_format
return internal_response.result
# Execute with or without middleware handlers
# Handler returns ModelResponse with messages and structured_response
if wrap_model_call_handler is None:
# No handlers - execute directly
output = base_handler(request)
response = _execute_model_sync(model_call)
else:
# Call composed handler with base handler
output = wrap_model_call_handler(request, base_handler)
return {
response = wrap_model_call_handler(request, _execute_model_sync)
# Build result dict with model call counts and messages
result: dict[str, Any] = {
"thread_model_call_count": state.get("thread_model_call_count", 0) + 1,
"run_model_call_count": state.get("run_model_call_count", 0) + 1,
**_handle_model_output(output, effective_response_format),
"messages": response.result,
}
async def _execute_model_async(request: ModelRequest) -> _InternalModelResponse:
"""Execute model asynchronously and return result or exception.
# Add structured response if present
if response.structured_response is not None:
result["structured_response"] = response.structured_response
return result
async def _execute_model_async(model_call: ModelCall) -> ModelResponse:
"""Execute model asynchronously and return ModelResponse.
Returns ModelResponse with messages and structured output.
This is the core async model execution logic wrapped by wrap_model_call handlers.
"""
try:
# Get the bound model (with auto-detection if needed)
model_, effective_response_format = _get_bound_model(request)
messages = request.messages
if request.system_prompt:
messages = [SystemMessage(request.system_prompt), *messages]
Handles model invocation, auto-detection of response format, and structured
output processing.
output = await model_.ainvoke(messages)
return _InternalModelResponse(
result=output,
exception=None,
effective_response_format=effective_response_format,
)
except Exception as error: # noqa: BLE001
# Catch all exceptions from model invocation
return _InternalModelResponse(
result=None,
exception=error,
effective_response_format=None,
)
Args:
model_call: The model call parameters.
Returns:
ModelResponse with result (list of messages) and structured_response (if applicable).
Raises:
Exception: Any exception from model invocation or structured output processing.
"""
# Get the bound model (with auto-detection if needed)
model_, effective_response_format = _get_bound_model(model_call)
messages = model_call.messages
if model_call.system_prompt:
messages = [SystemMessage(model_call.system_prompt), *messages]
output: AIMessage = await model_.ainvoke(messages)
# Handle structured output with provider strategy
if isinstance(effective_response_format, ProviderStrategy):
if not output.tool_calls:
provider_strategy_binding = ProviderStrategyBinding.from_schema_spec(
effective_response_format.schema_spec
)
structured_response = provider_strategy_binding.parse(output)
return ModelResponse(result=[output], structured_response=structured_response)
return ModelResponse(result=[output])
# Handle structured output with tool strategy
if (
isinstance(effective_response_format, ToolStrategy)
and isinstance(output, AIMessage)
and output.tool_calls
):
structured_tool_calls = [
tc for tc in output.tool_calls if tc["name"] in structured_output_tools
]
if structured_tool_calls:
if len(structured_tool_calls) > 1:
# Handle multiple structured outputs error
tool_names = [tc["name"] for tc in structured_tool_calls]
multiple_outputs_error = MultipleStructuredOutputsError(tool_names)
should_retry, error_message = _handle_structured_output_error(
multiple_outputs_error, effective_response_format
)
if not should_retry:
raise multiple_outputs_error
# Add error messages and retry
tool_messages = [
ToolMessage(
content=error_message,
tool_call_id=tc["id"],
name=tc["name"],
)
for tc in structured_tool_calls
]
return ModelResponse(result=[output, *tool_messages])
# Handle single structured output
tool_call = structured_tool_calls[0]
try:
structured_tool_binding = structured_output_tools[tool_call["name"]]
structured_response = structured_tool_binding.parse(tool_call["args"])
tool_message_content = (
effective_response_format.tool_message_content
if effective_response_format.tool_message_content
else f"Returning structured response: {structured_response}"
)
return ModelResponse(
result=[
output,
ToolMessage(
content=tool_message_content,
tool_call_id=tool_call["id"],
name=tool_call["name"],
),
],
structured_response=structured_response,
)
except Exception as exc: # noqa: BLE001
validation_error = StructuredOutputValidationError(tool_call["name"], exc)
should_retry, error_message = _handle_structured_output_error(
validation_error, effective_response_format
)
if not should_retry:
raise validation_error
return ModelResponse(
result=[
output,
ToolMessage(
content=error_message,
tool_call_id=tool_call["id"],
name=tool_call["name"],
),
]
)
return ModelResponse(result=[output])
async def amodel_node(state: AgentState, runtime: Runtime[ContextT]) -> dict[str, Any]:
"""Async model request handler with sequential middleware processing."""
request = ModelRequest(
# Create ModelCall with invocation parameters
model_call = ModelCall(
model=model,
tools=default_tools,
system_prompt=system_prompt,
response_format=initial_response_format,
messages=state["messages"],
tool_choice=None,
)
# Create ModelRequest with model_call + state + runtime
request = ModelRequest(
model_call=model_call,
state=state,
runtime=runtime,
)
# Execute with or without handler
effective_response_format: Any = None
# Define base async handler that executes the model
async def base_handler(req: ModelRequest) -> AIMessage:
nonlocal effective_response_format
internal_response = await _execute_model_async(req)
if internal_response.exception is not None:
raise internal_response.exception
if internal_response.result is None:
msg = "Model execution succeeded but returned no result"
raise RuntimeError(msg)
effective_response_format = internal_response.effective_response_format
return internal_response.result
# Execute with or without middleware handlers
# Handler returns ModelResponse with messages and structured_response
if awrap_model_call_handler is None:
# No async handlers - execute directly
output = await base_handler(request)
response = await _execute_model_async(model_call)
else:
# Call composed async handler with base handler
output = await awrap_model_call_handler(request, base_handler)
return {
response = await awrap_model_call_handler(request, _execute_model_async)
# Build result dict with model call counts and messages
result: dict[str, Any] = {
"thread_model_call_count": state.get("thread_model_call_count", 0) + 1,
"run_model_call_count": state.get("run_model_call_count", 0) + 1,
**_handle_model_output(output, effective_response_format),
"messages": response.result,
}
# Add structured response if present
if response.structured_response is not None:
result["structured_response"] = response.structured_response
return result
# Use sync or async based on model capabilities
graph.add_node("model", RunnableCallable(model_node, amodel_node, trace=False))
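The handler chaining above boils down to a small composition pattern: each middleware handler receives `(request, handler)`, and the composed chain nests them so the first handler in the list becomes the outermost wrapper, while only the base handler ever receives the final `ModelCall`. A minimal standalone sketch of that pattern — using string stand-ins rather than the library's `ModelCall`/`ModelResponse` types — mirroring how `compose_two` closes over `request` and ignores the argument passed to the wrapped inner handler:

```python
from __future__ import annotations

from functools import reduce
from typing import Callable

Handler = Callable[[str], str]              # stand-in for Callable[[ModelCall], ModelResponse]
Middleware = Callable[[str, Handler], str]  # stand-in for a wrap_model_call handler


def chain(handlers: list[Middleware]) -> Middleware | None:
    """Compose handlers so the first one in the list is the outermost wrapper."""
    if not handlers:
        return None

    def compose_two(outer: Middleware, inner: Middleware) -> Middleware:
        def composed(request: str, handler: Handler) -> str:
            # inner sees the base handler; outer sees inner (closed over the same
            # request) as its handler, and the wrapper ignores the call it is given.
            return outer(request, lambda _call: inner(request, handler))

        return composed

    return reduce(compose_two, handlers)


def outer(request: str, handler: Handler) -> str:
    return "outer(" + handler(request) + ")"


def inner(request: str, handler: Handler) -> str:
    return "inner(" + handler(request) + ")"


composed = chain([outer, inner])
print(composed("call", lambda call: call.upper()))  # -> outer(inner(CALL))
```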

View File

@@ -16,7 +16,9 @@ from .tool_selection import LLMToolSelectorMiddleware
from .types import (
AgentMiddleware,
AgentState,
ModelCall,
ModelRequest,
ModelResponse,
after_agent,
after_model,
before_agent,
@@ -35,9 +37,11 @@ __all__ = [
"ContextEditingMiddleware",
"HumanInTheLoopMiddleware",
"LLMToolSelectorMiddleware",
"ModelCall",
"ModelCallLimitMiddleware",
"ModelFallbackMiddleware",
"ModelRequest",
"ModelResponse",
"PIIDetectionError",
"PIIMiddleware",
"PlanningMiddleware",

View File

@@ -22,7 +22,12 @@ from langchain_core.messages import (
from langchain_core.messages.utils import count_tokens_approximately
from typing_extensions import Protocol
from langchain.agents.middleware.types import AgentMiddleware, ModelRequest
from langchain.agents.middleware.types import (
AgentMiddleware,
ModelCall,
ModelRequest,
ModelResponse,
)
DEFAULT_TOOL_PLACEHOLDER = "[cleared]"
@@ -209,11 +214,11 @@ class ContextEditingMiddleware(AgentMiddleware):
def wrap_model_call(
self,
request: ModelRequest,
handler: Callable[[ModelRequest], AIMessage],
) -> AIMessage:
handler: Callable[[ModelCall], ModelResponse],
) -> ModelResponse:
"""Apply context edits before invoking the model via handler."""
if not request.messages:
return handler(request)
if not request.model_call.messages:
return handler(request.model_call)
if self.token_count_method == "approximate": # noqa: S105
@@ -221,18 +226,20 @@ class ContextEditingMiddleware(AgentMiddleware):
return count_tokens_approximately(messages)
else:
system_msg = (
[SystemMessage(content=request.system_prompt)] if request.system_prompt else []
[SystemMessage(content=request.model_call.system_prompt)]
if request.model_call.system_prompt
else []
)
def count_tokens(messages: Sequence[BaseMessage]) -> int:
return request.model.get_num_tokens_from_messages(
system_msg + list(messages), request.tools
return request.model_call.model.get_num_tokens_from_messages(
system_msg + list(messages), request.model_call.tools
)
for edit in self.edits:
edit.apply(request.messages, count_tokens=count_tokens)
edit.apply(request.model_call.messages, count_tokens=count_tokens)
return handler(request)
return handler(request.model_call)
__all__ = [

View File

@@ -6,7 +6,9 @@ from typing import TYPE_CHECKING
from langchain.agents.middleware.types import (
AgentMiddleware,
ModelCall,
ModelRequest,
ModelResponse,
)
from langchain.chat_models import init_chat_model
@@ -14,7 +16,6 @@ if TYPE_CHECKING:
from collections.abc import Callable
from langchain_core.language_models.chat_models import BaseChatModel
from langchain_core.messages import AIMessage
class ModelFallbackMiddleware(AgentMiddleware):
@@ -68,18 +69,16 @@ class ModelFallbackMiddleware(AgentMiddleware):
def wrap_model_call(
self,
request: ModelRequest,
handler: Callable[[ModelRequest], AIMessage],
) -> AIMessage:
handler: Callable[[ModelCall], ModelResponse],
) -> ModelResponse:
"""Try fallback models in sequence on errors.
Args:
request: Initial model request.
state: Current agent state.
runtime: LangGraph runtime.
handler: Callback to execute the model.
request: Full model request including state and runtime.
handler: Callback to execute the model call.
Returns:
AIMessage from successful model call.
ModelResponse from successful model call.
Raises:
Exception: If all models fail, re-raises last exception.
@@ -87,15 +86,15 @@ class ModelFallbackMiddleware(AgentMiddleware):
# Try primary model first
last_exception: Exception
try:
return handler(request)
return handler(request.model_call)
except Exception as e: # noqa: BLE001
last_exception = e
# Try fallback models
for fallback_model in self.models:
request.model = fallback_model
request.model_call.model = fallback_model
try:
return handler(request)
return handler(request.model_call)
except Exception as e: # noqa: BLE001
last_exception = e
continue
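The same fallback pattern is straightforward to reproduce with the `wrap_model_call` decorator: mutate `request.model_call.model` and re-invoke the handler on failure. A minimal sketch under that assumption, with a hypothetical fallback model:

```python
from collections.abc import Callable

from langchain.agents.middleware.types import (
    ModelCall,
    ModelRequest,
    ModelResponse,
    wrap_model_call,
)
from langchain.chat_models import init_chat_model

# Hypothetical fallback; any BaseChatModel instance works here.
backup_model = init_chat_model("openai:gpt-4o-mini")


@wrap_model_call
def fall_back_to_backup(
    request: ModelRequest,
    handler: Callable[[ModelCall], ModelResponse],
) -> ModelResponse:
    try:
        return handler(request.model_call)
    except Exception:
        # Swap the model on the shared ModelCall and re-run it, as the
        # middleware above does for each configured fallback model.
        request.model_call.model = backup_model
        return handler(request.model_call)
```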

View File

@@ -8,12 +8,18 @@ from typing import TYPE_CHECKING, Annotated, Literal
if TYPE_CHECKING:
from collections.abc import Callable
from langchain_core.messages import AIMessage, ToolMessage
from langchain_core.messages import ToolMessage
from langchain_core.tools import tool
from langgraph.types import Command
from typing_extensions import NotRequired, TypedDict
from langchain.agents.middleware.types import AgentMiddleware, AgentState, ModelRequest
from langchain.agents.middleware.types import (
AgentMiddleware,
AgentState,
ModelCall,
ModelRequest,
ModelResponse,
)
from langchain.tools import InjectedToolCallId
@@ -189,12 +195,12 @@ class PlanningMiddleware(AgentMiddleware):
def wrap_model_call(
self,
request: ModelRequest,
handler: Callable[[ModelRequest], AIMessage],
) -> AIMessage:
handler: Callable[[ModelCall], ModelResponse],
) -> ModelResponse:
"""Update the system prompt to include the todo system prompt."""
request.system_prompt = (
request.system_prompt + "\n\n" + self.system_prompt
if request.system_prompt
request.model_call.system_prompt = (
request.model_call.system_prompt + "\n\n" + self.system_prompt
if request.model_call.system_prompt
else self.system_prompt
)
return handler(request)
return handler(request.model_call)

View File

@@ -4,9 +4,12 @@ from collections.abc import Callable
from typing import Literal
from warnings import warn
from langchain_core.messages import AIMessage
from langchain.agents.middleware.types import AgentMiddleware, ModelRequest
from langchain.agents.middleware.types import (
AgentMiddleware,
ModelCall,
ModelRequest,
ModelResponse,
)
class AnthropicPromptCachingMiddleware(AgentMiddleware):
@@ -45,8 +48,8 @@ class AnthropicPromptCachingMiddleware(AgentMiddleware):
def wrap_model_call(
self,
request: ModelRequest,
handler: Callable[[ModelRequest], AIMessage],
) -> AIMessage:
handler: Callable[[ModelCall], ModelResponse],
) -> ModelResponse:
"""Modify the model request to add cache control blocks."""
try:
from langchain_anthropic import ChatAnthropic
@@ -61,10 +64,10 @@ class AnthropicPromptCachingMiddleware(AgentMiddleware):
"Anthropic models. "
"Please install langchain-anthropic."
)
elif not isinstance(request.model, ChatAnthropic):
elif not isinstance(request.model_call.model, ChatAnthropic):
msg = (
"AnthropicPromptCachingMiddleware caching middleware only supports "
f"Anthropic models, not instances of {type(request.model)}"
f"Anthropic models, not instances of {type(request.model_call.model)}"
)
if msg is not None:
@@ -73,14 +76,16 @@ class AnthropicPromptCachingMiddleware(AgentMiddleware):
if self.unsupported_model_behavior == "warn":
warn(msg, stacklevel=3)
else:
return handler(request)
return handler(request.model_call)
messages_count = (
len(request.messages) + 1 if request.system_prompt else len(request.messages)
len(request.model_call.messages) + 1
if request.model_call.system_prompt
else len(request.model_call.messages)
)
if messages_count < self.min_messages_to_cache:
return handler(request)
return handler(request.model_call)
request.model_settings["cache_control"] = {"type": self.type, "ttl": self.ttl}
request.model_call.model_settings["cache_control"] = {"type": self.type, "ttl": self.ttl}
return handler(request)
return handler(request.model_call)
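The cache-control change above illustrates the general rule after this refactor: invocation parameters such as `model_settings` now live on `request.model_call` and are mutated in place before the handler runs. A small sketch of the same pattern with an illustrative setting:

```python
from collections.abc import Callable

from langchain.agents.middleware.types import (
    ModelCall,
    ModelRequest,
    ModelResponse,
    wrap_model_call,
)


@wrap_model_call
def low_temperature(
    request: ModelRequest,
    handler: Callable[[ModelCall], ModelResponse],
) -> ModelResponse:
    # model_settings is a plain dict on ModelCall; extra kwargs set here are
    # forwarded to bind()/bind_tools() when the model is invoked.
    request.model_call.model_settings["temperature"] = 0.7  # illustrative value
    return handler(request.model_call)
```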

View File

@@ -12,11 +12,16 @@ if TYPE_CHECKING:
from langchain.tools import BaseTool
from langchain_core.language_models.chat_models import BaseChatModel
from langchain_core.messages import AIMessage, HumanMessage
from langchain_core.messages import HumanMessage
from pydantic import Field, TypeAdapter
from typing_extensions import TypedDict
from langchain.agents.middleware.types import AgentMiddleware, ModelRequest
from langchain.agents.middleware.types import (
AgentMiddleware,
ModelCall,
ModelRequest,
ModelResponse,
)
from langchain.chat_models.base import init_chat_model
logger = logging.getLogger(__name__)
@@ -142,11 +147,11 @@ class LLMToolSelectorMiddleware(AgentMiddleware):
SelectionRequest with prepared inputs, or None if no selection is needed.
"""
# If no tools available, return None
if not request.tools or len(request.tools) == 0:
if not request.model_call.tools or len(request.model_call.tools) == 0:
return None
# Filter to only BaseTool instances (exclude provider-specific tool dicts)
base_tools = [tool for tool in request.tools if not isinstance(tool, dict)]
base_tools = [tool for tool in request.model_call.tools if not isinstance(tool, dict)]
# Validate that always_include tools exist
if self.always_include:
@@ -180,7 +185,7 @@ class LLMToolSelectorMiddleware(AgentMiddleware):
# Get the last user message from the conversation history
last_user_message: HumanMessage
for message in reversed(request.messages):
for message in reversed(request.model_call.messages):
if isinstance(message, HumanMessage):
last_user_message = message
break
@@ -188,7 +193,7 @@ class LLMToolSelectorMiddleware(AgentMiddleware):
msg = "No user message found in request messages"
raise AssertionError(msg)
model = self.model or request.model
model = self.model or request.model_call.model
valid_tool_names = [tool.name for tool in available_tools]
return _SelectionRequest(
@@ -205,8 +210,8 @@ class LLMToolSelectorMiddleware(AgentMiddleware):
available_tools: list[BaseTool],
valid_tool_names: list[str],
request: ModelRequest,
) -> ModelRequest:
"""Process the selection response and return filtered ModelRequest."""
) -> None:
"""Process the selection response and update ModelRequest with filtered tools."""
selected_tool_names: list[str] = []
invalid_tool_selections = []
@@ -231,26 +236,25 @@ class LLMToolSelectorMiddleware(AgentMiddleware):
]
always_included_tools: list[BaseTool] = [
tool
for tool in request.tools
for tool in request.model_call.tools
if not isinstance(tool, dict) and tool.name in self.always_include
]
selected_tools.extend(always_included_tools)
# Also preserve any provider-specific tool dicts from the original request
provider_tools = [tool for tool in request.tools if isinstance(tool, dict)]
provider_tools = [tool for tool in request.model_call.tools if isinstance(tool, dict)]
request.tools = [*selected_tools, *provider_tools]
return request
request.model_call.tools = [*selected_tools, *provider_tools]
def wrap_model_call(
self,
request: ModelRequest,
handler: Callable[[ModelRequest], AIMessage],
) -> AIMessage:
handler: Callable[[ModelCall], ModelResponse],
) -> ModelResponse:
"""Filter tools based on LLM selection before invoking the model via handler."""
selection_request = self._prepare_selection_request(request)
if selection_request is None:
return handler(request)
return handler(request.model_call)
# Create dynamic response model with Literal enum of available tool names
type_adapter = _create_tool_selection_response(selection_request.available_tools)
@@ -268,20 +272,20 @@ class LLMToolSelectorMiddleware(AgentMiddleware):
if not isinstance(response, dict):
msg = f"Expected dict response, got {type(response)}"
raise AssertionError(msg)
modified_request = self._process_selection_response(
self._process_selection_response(
response, selection_request.available_tools, selection_request.valid_tool_names, request
)
return handler(modified_request)
return handler(request.model_call)
async def awrap_model_call(
self,
request: ModelRequest,
handler: Callable[[ModelRequest], Awaitable[AIMessage]],
) -> AIMessage:
handler: Callable[[ModelCall], Awaitable[ModelResponse]],
) -> ModelResponse:
"""Filter tools based on LLM selection before invoking the model via handler."""
selection_request = self._prepare_selection_request(request)
if selection_request is None:
return await handler(request)
return await handler(request.model_call)
# Create dynamic response model with Literal enum of available tool names
type_adapter = _create_tool_selection_response(selection_request.available_tools)
@@ -299,7 +303,7 @@ class LLMToolSelectorMiddleware(AgentMiddleware):
if not isinstance(response, dict):
msg = f"Expected dict response, got {type(response)}"
raise AssertionError(msg)
modified_request = self._process_selection_response(
self._process_selection_response(
response, selection_request.available_tools, selection_request.valid_tool_names, request
)
return await handler(modified_request)
return await handler(request.model_call)
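Note that `_process_selection_response` now mutates `request.model_call.tools` in place instead of returning a new `ModelRequest`. The equivalent pattern in a decorator-based middleware, with a hypothetical allow-list, would look roughly like this:

```python
from collections.abc import Callable

from langchain.agents.middleware.types import (
    ModelCall,
    ModelRequest,
    ModelResponse,
    wrap_model_call,
)

ALLOWED = {"get_weather", "search_web"}  # hypothetical allow-list


@wrap_model_call
def keep_allowed_tools(
    request: ModelRequest,
    handler: Callable[[ModelCall], ModelResponse],
) -> ModelResponse:
    request.model_call.tools = [
        tool
        for tool in request.model_call.tools
        # Keep provider-specific dict tools as-is; filter BaseTool instances by name.
        if isinstance(tool, dict) or tool.name in ALLOWED
    ]
    return handler(request.model_call)
```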

View File

@@ -2,7 +2,7 @@
from __future__ import annotations
from collections.abc import Awaitable, Callable
from collections.abc import Callable
from dataclasses import dataclass, field
from inspect import iscoroutinefunction
from typing import (
@@ -22,7 +22,7 @@ if TYPE_CHECKING:
from langchain.tools.tool_node import ToolCallRequest
# needed as top level import for pydantic schema generation on AgentState
from langchain_core.messages import AIMessage, AnyMessage, ToolMessage # noqa: TC002
from langchain_core.messages import AnyMessage, BaseMessage, ToolMessage # noqa: TC002
from langgraph.channels.ephemeral_value import EphemeralValue
from langgraph.channels.untracked_value import UntrackedValue
from langgraph.graph.message import add_messages
@@ -41,7 +41,9 @@ __all__ = [
"AgentMiddleware",
"AgentState",
"ContextT",
"ModelCall",
"ModelRequest",
"ModelResponse",
"OmitFromSchema",
"PublicAgentState",
"after_agent",
@@ -60,8 +62,11 @@ ResponseT = TypeVar("ResponseT")
@dataclass
class ModelRequest:
"""Model request information for the agent."""
class ModelCall:
"""Model invocation parameters for a single model call.
Contains only the parameters needed to invoke the model, without agent context.
"""
model: BaseChatModel
system_prompt: str | None
@@ -69,9 +74,34 @@ class ModelRequest:
tool_choice: Any | None
tools: list[BaseTool | dict]
response_format: ResponseFormat | None
model_settings: dict[str, Any] = field(default_factory=dict)
@dataclass
class ModelRequest:
"""Full request context for model invocation including agent state.
Combines model invocation parameters with agent state and runtime context.
"""
model_call: ModelCall
state: AgentState
runtime: Runtime[ContextT] # type: ignore[valid-type]
model_settings: dict[str, Any] = field(default_factory=dict)
@dataclass
class ModelResponse:
"""Response from model execution including messages and optional structured output.
The result will usually contain a single AIMessage, but may include
an additional ToolMessage if the model used a tool for structured output.
"""
result: list[BaseMessage]
"""List of messages from model execution."""
structured_response: Any = None
"""Parsed structured output if response_format was specified, None otherwise."""
@dataclass
@@ -167,23 +197,23 @@ class AgentMiddleware(Generic[StateT, ContextT]):
def wrap_model_call(
self,
request: ModelRequest,
handler: Callable[[ModelRequest], AIMessage],
) -> AIMessage:
handler: Callable[[ModelCall], ModelResponse],
) -> ModelResponse:
"""Intercept and control model execution via handler callback.
The handler callback executes the model request and returns an AIMessage.
Middleware can call the handler multiple times for retry logic, skip calling
it to short-circuit, or modify the request/response. Multiple middleware
compose with first in list as outermost layer.
The handler callback executes the model call and returns a ModelResponse containing
messages and optional structured_response. Middleware can call the handler multiple
times for retry logic, skip calling it to short-circuit, or modify the request/response.
Multiple middleware compose with first in list as outermost layer.
Args:
request: Model request to execute (includes state and runtime).
handler: Callback that executes the model request and returns AIMessage.
Call this to execute the model. Can be called multiple times
for retry logic. Can skip calling it to short-circuit.
request: Full model request including state and runtime context.
handler: Callback that executes the model call and returns ModelResponse.
Pass request.model_call to execute the model. Can be called
multiple times for retry logic. Can skip calling it to short-circuit.
Returns:
Final AIMessage to use (from handler or custom).
Final ModelResponse to use (from handler or custom).
Examples:
Retry on error:
@@ -191,36 +221,40 @@ class AgentMiddleware(Generic[StateT, ContextT]):
def wrap_model_call(self, request, handler):
for attempt in range(3):
try:
return handler(request)
return handler(request.model_call)
except Exception:
if attempt == 2:
raise
```
Rewrite response:
Modify messages:
```python
def wrap_model_call(self, request, handler):
result = handler(request)
return AIMessage(content=f"[{result.content}]")
response = handler(request.model_call)
# Modify first message (AIMessage)
ai_msg = response.result[0]
modified = AIMessage(content=f"[{ai_msg.content}]")
return ModelResponse(
result=[modified, *response.result[1:]],
structured_response=response.structured_response,
)
```
Error to fallback:
```python
def wrap_model_call(self, request, handler):
try:
return handler(request)
return handler(request.model_call)
except Exception:
return AIMessage(content="Service unavailable")
return ModelResponse(result=[AIMessage(content="Service unavailable")])
```
Cache/short-circuit:
Modify model settings:
```python
def wrap_model_call(self, request, handler):
if cached := get_cache(request):
return cached # Short-circuit with cached result
result = handler(request)
save_cache(request, result)
return result
# Modify the model call parameters
request.model_call.model_settings["temperature"] = 0.7
return handler(request.model_call)
```
"""
raise NotImplementedError
@@ -228,16 +262,17 @@ class AgentMiddleware(Generic[StateT, ContextT]):
async def awrap_model_call(
self,
request: ModelRequest,
handler: Callable[[ModelRequest], Awaitable[AIMessage]],
) -> AIMessage:
handler: Callable[[ModelCall], Awaitable[ModelResponse]],
) -> ModelResponse:
"""Async version of wrap_model_call.
Args:
request: Model request to execute (includes state and runtime).
handler: Async callback that executes the model request.
request: Full model request including state and runtime context.
handler: Async callback that executes the model call and returns ModelResponse.
Pass request.model_call to execute the model.
Returns:
Final AIMessage to use (from handler or custom).
Final ModelResponse to use (from handler or custom).
Examples:
Retry on error:
@@ -245,7 +280,7 @@ class AgentMiddleware(Generic[StateT, ContextT]):
async def awrap_model_call(self, request, handler):
for attempt in range(3):
try:
return await handler(request)
return await handler(request.model_call)
except Exception:
if attempt == 2:
raise
@@ -337,14 +372,14 @@ class _CallableReturningPromptString(Protocol[StateT_contra, ContextT]): # type
class _CallableReturningModelResponse(Protocol[StateT_contra, ContextT]): # type: ignore[misc]
"""Callable for model call interception with handler callback.
Receives handler callback to execute model and returns final AIMessage.
Receives handler callback to execute model and returns final ModelResponse.
"""
def __call__(
self,
request: ModelRequest,
handler: Callable[[ModelRequest], AIMessage],
) -> AIMessage:
handler: Callable[[ModelCall], ModelResponse],
) -> ModelResponse:
"""Intercept model execution via handler callback."""
...
@@ -1037,11 +1072,11 @@ def dynamic_prompt(
async def async_wrapped(
self: AgentMiddleware[StateT, ContextT], # noqa: ARG001
request: ModelRequest,
handler: Callable[[ModelRequest], Awaitable[AIMessage]],
) -> AIMessage:
handler: Callable[[ModelCall], Awaitable[ModelResponse]],
) -> ModelResponse:
prompt = await func(request) # type: ignore[misc]
request.system_prompt = prompt
return await handler(request)
request.model_call.system_prompt = prompt
return await handler(request.model_call)
middleware_name = cast("str", getattr(func, "__name__", "DynamicPromptMiddleware"))
@@ -1058,11 +1093,11 @@ def dynamic_prompt(
def wrapped(
self: AgentMiddleware[StateT, ContextT], # noqa: ARG001
request: ModelRequest,
handler: Callable[[ModelRequest], AIMessage],
) -> AIMessage:
handler: Callable[[ModelCall], ModelResponse],
) -> ModelResponse:
prompt = cast("str", func(request))
request.system_prompt = prompt
return handler(request)
request.model_call.system_prompt = prompt
return handler(request.model_call)
middleware_name = cast("str", getattr(func, "__name__", "DynamicPromptMiddleware"))
@@ -1176,8 +1211,8 @@ def wrap_model_call(
async def async_wrapped(
self: AgentMiddleware[StateT, ContextT], # noqa: ARG001
request: ModelRequest,
handler: Callable[[ModelRequest], Awaitable[AIMessage]],
) -> AIMessage:
handler: Callable[[ModelCall], Awaitable[ModelResponse]],
) -> ModelResponse:
return await func(request, handler) # type: ignore[misc, arg-type]
middleware_name = name or cast(
@@ -1197,8 +1232,8 @@ def wrap_model_call(
def wrapped(
self: AgentMiddleware[StateT, ContextT], # noqa: ARG001
request: ModelRequest,
handler: Callable[[ModelRequest], AIMessage],
) -> AIMessage:
handler: Callable[[ModelCall], ModelResponse],
) -> ModelResponse:
return func(request, handler)
middleware_name = name or cast("str", getattr(func, "__name__", "WrapModelCallMiddleware"))

View File

@@ -115,7 +115,7 @@ class TestLLMToolSelectorBasic:
"""Middleware to select relevant tools based on state/context."""
# Select a small, relevant subset of tools based on state/context
model_requests.append(request)
return handler(request)
return handler(request.model_call)
tool_selection_model = FakeModel(
messages=cycle(
@@ -161,7 +161,9 @@ class TestLLMToolSelectorBasic:
assert isinstance(response["messages"][-1], AIMessage)
for request in model_requests:
selected_tool_names = [tool.name for tool in request.tools] if request.tools else []
selected_tool_names = (
[tool.name for tool in request.model_call.tools] if request.model_call.tools else []
)
assert selected_tool_names == ["get_weather", "calculate"]
async def test_async_basic_selection(self) -> None:
@@ -218,7 +220,7 @@ class TestMaxToolsLimiting:
@wrap_model_call
def trace_model_requests(request, handler):
model_requests.append(request)
return handler(request)
return handler(request.model_call)
# Selector model tries to select 4 tools
tool_selection_model = FakeModel(
@@ -261,8 +263,8 @@ class TestMaxToolsLimiting:
# Verify only 2 tools were passed to the main model
assert len(model_requests) > 0
for request in model_requests:
assert len(request.tools) == 2
tool_names = [tool.name for tool in request.tools]
assert len(request.model_call.tools) == 2
tool_names = [tool.name for tool in request.model_call.tools]
# Should be first 2 from the selection
assert tool_names == ["get_weather", "search_web"]
@@ -273,7 +275,7 @@ class TestMaxToolsLimiting:
@wrap_model_call
def trace_model_requests(request, handler):
model_requests.append(request)
return handler(request)
return handler(request.model_call)
tool_selection_model = FakeModel(
messages=cycle(
@@ -315,8 +317,8 @@ class TestMaxToolsLimiting:
# All 4 selected tools should be present
assert len(model_requests) > 0
for request in model_requests:
assert len(request.tools) == 4
tool_names = [tool.name for tool in request.tools]
assert len(request.model_call.tools) == 4
tool_names = [tool.name for tool in request.model_call.tools]
assert set(tool_names) == {
"get_weather",
"search_web",
@@ -335,7 +337,7 @@ class TestAlwaysInclude:
@wrap_model_call
def trace_model_requests(request, handler):
model_requests.append(request)
return handler(request)
return handler(request.model_call)
# Selector picks only search_web
tool_selection_model = FakeModel(
@@ -373,7 +375,7 @@ class TestAlwaysInclude:
# Both selected and always_include tools should be present
assert len(model_requests) > 0
for request in model_requests:
tool_names = [tool.name for tool in request.tools]
tool_names = [tool.name for tool in request.model_call.tools]
assert "search_web" in tool_names
assert "send_email" in tool_names
assert len(tool_names) == 2
@@ -385,7 +387,7 @@ class TestAlwaysInclude:
@wrap_model_call
def trace_model_requests(request, handler):
model_requests.append(request)
return handler(request)
return handler(request.model_call)
# Selector picks 2 tools
tool_selection_model = FakeModel(
@@ -425,8 +427,8 @@ class TestAlwaysInclude:
# Should have 2 selected + 2 always_include = 4 total
assert len(model_requests) > 0
for request in model_requests:
assert len(request.tools) == 4
tool_names = [tool.name for tool in request.tools]
assert len(request.model_call.tools) == 4
tool_names = [tool.name for tool in request.model_call.tools]
assert "get_weather" in tool_names
assert "search_web" in tool_names
assert "send_email" in tool_names
@@ -439,7 +441,7 @@ class TestAlwaysInclude:
@wrap_model_call
def trace_model_requests(request, handler):
model_requests.append(request)
return handler(request)
return handler(request.model_call)
# Selector picks 1 tool
tool_selection_model = FakeModel(
@@ -478,8 +480,8 @@ class TestAlwaysInclude:
# Should have 1 selected + 3 always_include = 4 total
assert len(model_requests) > 0
for request in model_requests:
assert len(request.tools) == 4
tool_names = [tool.name for tool in request.tools]
assert len(request.model_call.tools) == 4
tool_names = [tool.name for tool in request.model_call.tools]
assert "get_weather" in tool_names
assert "send_email" in tool_names
assert "calculate" in tool_names
@@ -496,7 +498,7 @@ class TestDuplicateAndInvalidTools:
@wrap_model_call
def trace_model_requests(request, handler):
model_requests.append(request)
return handler(request)
return handler(request.model_call)
# Selector returns duplicates
tool_selection_model = FakeModel(
@@ -538,7 +540,7 @@ class TestDuplicateAndInvalidTools:
# Duplicates should be removed
assert len(model_requests) > 0
for request in model_requests:
tool_names = [tool.name for tool in request.tools]
tool_names = [tool.name for tool in request.model_call.tools]
assert tool_names == ["get_weather", "search_web"]
assert len(tool_names) == 2
@@ -549,7 +551,7 @@ class TestDuplicateAndInvalidTools:
@wrap_model_call
def trace_model_requests(request, handler):
model_requests.append(request)
return handler(request)
return handler(request.model_call)
# Selector returns duplicates but max_tools=2
tool_selection_model = FakeModel(
@@ -592,7 +594,7 @@ class TestDuplicateAndInvalidTools:
# Should deduplicate and respect max_tools
assert len(model_requests) > 0
for request in model_requests:
tool_names = [tool.name for tool in request.tools]
tool_names = [tool.name for tool in request.model_call.tools]
assert len(tool_names) == 2
assert "get_weather" in tool_names
assert "search_web" in tool_names

View File

@@ -8,7 +8,9 @@ from langchain.agents import create_agent
from langchain.agents.middleware.types import (
AgentMiddleware,
AgentState,
ModelCall,
ModelRequest,
ModelResponse,
wrap_model_call,
)
@@ -21,7 +23,7 @@ class TestOnModelCallDecorator:
@wrap_model_call
def passthrough_middleware(request, handler):
return handler(request)
return handler(request.model_call)
# Should return an AgentMiddleware instance
assert isinstance(passthrough_middleware, AgentMiddleware)
@@ -39,7 +41,7 @@ class TestOnModelCallDecorator:
@wrap_model_call(name="CustomMiddleware")
def my_middleware(request, handler):
return handler(request)
return handler(request.model_call)
assert isinstance(my_middleware, AgentMiddleware)
assert my_middleware.__class__.__name__ == "CustomMiddleware"
@@ -58,10 +60,10 @@ class TestOnModelCallDecorator:
@wrap_model_call
def retry_once(request, handler):
try:
return handler(request)
return handler(request.model_call)
except Exception:
# Retry once
return handler(request)
return handler(request.model_call)
model = FailOnceThenSucceed(messages=iter([AIMessage(content="Success")]))
agent = create_agent(model=model, middleware=[retry_once])
@@ -76,8 +78,8 @@ class TestOnModelCallDecorator:
@wrap_model_call
def uppercase_responses(request, handler):
result = handler(request)
return AIMessage(content=result.content.upper())
result = handler(request.model_call)
return ModelResponse(result=[AIMessage(content=result.result[0].content.upper())])
model = GenericFakeChatModel(messages=iter([AIMessage(content="hello world")]))
agent = create_agent(model=model, middleware=[uppercase_responses])
@@ -96,9 +98,9 @@ class TestOnModelCallDecorator:
@wrap_model_call
def error_to_fallback(request, handler):
try:
return handler(request)
return handler(request.model_call)
except Exception:
return AIMessage(content="Fallback response")
return ModelResponse(result=[AIMessage(content="Fallback response")])
model = AlwaysFailModel(messages=iter([]))
agent = create_agent(model=model, middleware=[error_to_fallback])
@@ -114,7 +116,7 @@ class TestOnModelCallDecorator:
@wrap_model_call
def log_state(request, handler):
state_values.append(request.state.get("messages"))
return handler(request)
return handler(request.model_call)
model = GenericFakeChatModel(messages=iter([AIMessage(content="Response")]))
agent = create_agent(model=model, middleware=[log_state])
@@ -133,14 +135,14 @@ class TestOnModelCallDecorator:
@wrap_model_call
def outer_middleware(request, handler):
execution_order.append("outer-before")
result = handler(request)
result = handler(request.model_call)
execution_order.append("outer-after")
return result
@wrap_model_call
def inner_middleware(request, handler):
execution_order.append("inner-before")
result = handler(request)
result = handler(request.model_call)
execution_order.append("inner-after")
return result
@@ -166,7 +168,7 @@ class TestOnModelCallDecorator:
@wrap_model_call(state_schema=CustomState)
def middleware_with_schema(request, handler):
return handler(request)
return handler(request.model_call)
assert isinstance(middleware_with_schema, AgentMiddleware)
# Custom state schema should be set
@@ -183,7 +185,7 @@ class TestOnModelCallDecorator:
@wrap_model_call(tools=[test_tool])
def middleware_with_tools(request, handler):
return handler(request)
return handler(request.model_call)
assert isinstance(middleware_with_tools, AgentMiddleware)
assert len(middleware_with_tools.tools) == 1
@@ -195,12 +197,12 @@ class TestOnModelCallDecorator:
# Without parentheses
@wrap_model_call
def middleware_no_parens(request, handler):
return handler(request)
return handler(request.model_call)
# With parentheses
@wrap_model_call()
def middleware_with_parens(request, handler):
return handler(request)
return handler(request.model_call)
assert isinstance(middleware_no_parens, AgentMiddleware)
assert isinstance(middleware_with_parens, AgentMiddleware)
@@ -210,7 +212,7 @@ class TestOnModelCallDecorator:
@wrap_model_call
def my_custom_middleware(request, handler):
return handler(request)
return handler(request.model_call)
assert my_custom_middleware.__class__.__name__ == "my_custom_middleware"
@@ -221,14 +223,14 @@ class TestOnModelCallDecorator:
@wrap_model_call
def decorated_middleware(request, handler):
execution_order.append("decorated-before")
result = handler(request)
result = handler(request.model_call)
execution_order.append("decorated-after")
return result
class ClassMiddleware(AgentMiddleware):
def wrap_model_call(self, request, handler):
execution_order.append("class-before")
result = handler(request)
result = handler(request.model_call)
execution_order.append("class-after")
return result
@@ -267,7 +269,7 @@ class TestOnModelCallDecorator:
for attempt in range(max_retries):
attempts.append(attempt + 1)
try:
return handler(request)
return handler(request.model_call)
except Exception as e:
last_exception = e
# On error, continue to next attempt
@@ -291,7 +293,7 @@ class TestOnModelCallDecorator:
@wrap_model_call
async def logging_middleware(request, handler):
call_log.append("before")
result = await handler(request)
result = await handler(request.model_call)
call_log.append("after")
return result
@@ -310,18 +312,9 @@ class TestOnModelCallDecorator:
@wrap_model_call
def add_system_prompt(request, handler):
# Modify request to add system prompt
modified_request = ModelRequest(
messages=request.messages,
model=request.model,
system_prompt="You are a helpful assistant",
tool_choice=request.tool_choice,
tools=request.tools,
response_format=request.response_format,
state={},
runtime=None,
)
modified_prompts.append(modified_request.system_prompt)
return handler(modified_request)
request.model_call.system_prompt = "You are a helpful assistant"
modified_prompts.append(request.model_call.system_prompt)
return handler(request.model_call)
model = GenericFakeChatModel(messages=iter([AIMessage(content="Response")]))
agent = create_agent(model=model, middleware=[add_system_prompt])
@@ -335,13 +328,13 @@ class TestOnModelCallDecorator:
@wrap_model_call
def multi_transform(request, handler):
result = handler(request)
result = handler(request.model_call)
# First transformation: uppercase
content = result.content.upper()
content = result.result[0].content.upper()
# Second transformation: add prefix and suffix
content = f"[START] {content} [END]"
return AIMessage(content=content)
return ModelResponse(result=[AIMessage(content=content)])
model = GenericFakeChatModel(messages=iter([AIMessage(content="hello")]))
agent = create_agent(model=model, middleware=[multi_transform])

View File

@@ -9,6 +9,7 @@ from langchain.agents.middleware.types import (
AgentMiddleware,
AgentState,
ModelRequest,
ModelResponse,
)
@@ -20,7 +21,7 @@ class TestBasicOnModelCall:
class PassthroughMiddleware(AgentMiddleware):
def wrap_model_call(self, request, handler):
return handler(request)
return handler(request.model_call)
model = GenericFakeChatModel(messages=iter([AIMessage(content="Hello")]))
agent = create_agent(model=model, middleware=[PassthroughMiddleware()])
@@ -37,7 +38,7 @@ class TestBasicOnModelCall:
class LoggingMiddleware(AgentMiddleware):
def wrap_model_call(self, request, handler):
call_log.append("before")
result = handler(request)
result = handler(request.model_call)
call_log.append("after")
return result
@@ -59,7 +60,7 @@ class TestBasicOnModelCall:
def wrap_model_call(self, request, handler):
self.call_count += 1
return handler(request)
return handler(request.model_call)
counter = CountingMiddleware()
model = GenericFakeChatModel(messages=iter([AIMessage(content="Reply")]))
@@ -91,11 +92,11 @@ class TestRetryMiddleware:
def wrap_model_call(self, request, handler):
try:
result = handler(request)
result = handler(request.model_call)
return result
except Exception:
self.retry_count += 1
result = handler(request)
result = handler(request.model_call)
return result
retry_middleware = RetryOnceMiddleware()
@@ -125,7 +126,7 @@ class TestRetryMiddleware:
for attempt in range(self.max_retries):
self.attempts.append(attempt + 1)
try:
result = handler(request)
result = handler(request.model_call)
return result
except Exception as e:
last_exception = e
@@ -152,8 +153,9 @@ class TestResponseRewriting:
class UppercaseMiddleware(AgentMiddleware):
def wrap_model_call(self, request, handler):
result = handler(request)
return AIMessage(content=result.content.upper())
result = handler(request.model_call)
ai_msg = result.result[0]
return ModelResponse(result=[AIMessage(content=ai_msg.content.upper())])
model = GenericFakeChatModel(messages=iter([AIMessage(content="hello world")]))
agent = create_agent(model=model, middleware=[UppercaseMiddleware()])
@@ -171,8 +173,9 @@ class TestResponseRewriting:
self.prefix = prefix
def wrap_model_call(self, request, handler):
result = handler(request)
return AIMessage(content=f"{self.prefix}{result.content}")
result = handler(request.model_call)
ai_msg = result.result[0]
return ModelResponse(result=[AIMessage(content=f"{self.prefix}{ai_msg.content}")])
model = GenericFakeChatModel(messages=iter([AIMessage(content="Response")]))
agent = create_agent(model=model, middleware=[PrefixMiddleware(prefix="[BOT]: ")])
@@ -195,10 +198,9 @@ class TestErrorHandling:
class ErrorToSuccessMiddleware(AgentMiddleware):
def wrap_model_call(self, request, handler):
try:
return handler(request)
return handler(request.model_call)
except Exception:
fallback = AIMessage(content="Error handled gracefully")
return fallback
return ModelResponse(result=[AIMessage(content="Error handled gracefully")])
model = AlwaysFailModel(messages=iter([]))
agent = create_agent(model=model, middleware=[ErrorToSuccessMiddleware()])
@@ -218,10 +220,11 @@ class TestErrorHandling:
class SelectiveErrorMiddleware(AgentMiddleware):
def wrap_model_call(self, request, handler):
try:
return handler(request)
return handler(request.model_call)
except ConnectionError:
fallback = AIMessage(content="Network issue, try again later")
return fallback
return ModelResponse(
result=[AIMessage(content="Network issue, try again later")]
)
model = SpecificErrorModel(messages=iter([]))
agent = create_agent(model=model, middleware=[SelectiveErrorMiddleware()])
@@ -238,13 +241,12 @@ class TestErrorHandling:
def wrap_model_call(self, request, handler):
try:
call_log.append("before-yield")
result = handler(request)
result = handler(request.model_call)
call_log.append("after-yield-success")
return result
except Exception:
call_log.append("caught-error")
fallback = AIMessage(content="Recovered from error")
return fallback
return ModelResponse(result=[AIMessage(content="Recovered from error")])
# Test 1: Success path
call_log.clear()
@@ -281,14 +283,18 @@ class TestShortCircuit:
class CachingMiddleware(AgentMiddleware):
def wrap_model_call(self, request, handler):
# Simple cache key based on last message
cache_key = str(request.messages[-1].content) if request.messages else ""
cache_key = (
str(request.model_call.messages[-1].content)
if request.model_call.messages
else ""
)
if cache_key in cache:
# Short-circuit with cached result
return cache[cache_key]
else:
# Execute and cache
result = handler(request)
result = handler(request.model_call)
cache[cache_key] = result
return result
@@ -337,19 +343,9 @@ class TestRequestModification:
def wrap_model_call(self, request, handler):
# Modify request to add system prompt
modified_request = ModelRequest(
model=request.model,
system_prompt=self.system_prompt,
messages=request.messages,
tools=request.tools,
tool_choice=request.tool_choice,
response_format=request.response_format,
model_settings=request.model_settings,
state=request.state,
runtime=request.runtime,
)
received_requests.append(modified_request)
return handler(modified_request)
request.model_call.system_prompt = self.system_prompt
received_requests.append(request)
return handler(request.model_call)
model = GenericFakeChatModel(messages=iter([AIMessage(content="Response")]))
agent = create_agent(
@@ -360,7 +356,7 @@ class TestRequestModification:
result = agent.invoke({"messages": [HumanMessage("Test")]})
assert len(received_requests) == 1
assert received_requests[0].system_prompt == "You are a helpful assistant."
assert received_requests[0].model_call.system_prompt == "You are a helpful assistant."
assert result["messages"][1].content == "Response"
@@ -380,7 +376,7 @@ class TestStateAndRuntime:
"messages_count": len(request.state.get("messages", [])),
}
)
return handler(request)
return handler(request.model_call)
model = GenericFakeChatModel(messages=iter([AIMessage(content="Response")]))
agent = create_agent(model=model, middleware=[StateAwareMiddleware()])
@@ -399,7 +395,7 @@ class TestStateAndRuntime:
max_retries = 2
for attempt in range(max_retries):
try:
return handler(request)
return handler(request.model_call)
break # Success
except Exception:
if attempt == max_retries - 1:
@@ -433,14 +429,14 @@ class TestMiddlewareComposition:
class OuterMiddleware(AgentMiddleware):
def wrap_model_call(self, request, handler):
execution_order.append("outer-before")
response = handler(request)
response = handler(request.model_call)
execution_order.append("outer-after")
return response
class InnerMiddleware(AgentMiddleware):
def wrap_model_call(self, request, handler):
execution_order.append("inner-before")
response = handler(request)
response = handler(request.model_call)
execution_order.append("inner-after")
return response
@@ -472,7 +468,7 @@ class TestMiddlewareComposition:
class LoggingMiddleware(AgentMiddleware):
def wrap_model_call(self, request, handler):
log.append("logging-before")
result = handler(request)
result = handler(request.model_call)
log.append("logging-after")
return result
@@ -480,12 +476,12 @@ class TestMiddlewareComposition:
def wrap_model_call(self, request, handler):
log.append("retry-before")
try:
result = handler(request)
result = handler(request.model_call)
log.append("retry-after")
return result
except Exception:
log.append("retry-retrying")
result = handler(request)
result = handler(request.model_call)
log.append("retry-after")
return result
@@ -510,13 +506,15 @@ class TestMiddlewareComposition:
class PrefixMiddleware(AgentMiddleware):
def wrap_model_call(self, request, handler):
result = handler(request)
return AIMessage(content=f"[PREFIX] {result.content}")
result = handler(request.model_call)
ai_msg = result.result[0]
return ModelResponse(result=[AIMessage(content=f"[PREFIX] {ai_msg.content}")])
class SuffixMiddleware(AgentMiddleware):
def wrap_model_call(self, request, handler):
result = handler(request)
return AIMessage(content=f"{result.content} [SUFFIX]")
result = handler(request.model_call)
ai_msg = result.result[0]
return ModelResponse(result=[AIMessage(content=f"{ai_msg.content} [SUFFIX]")])
model = GenericFakeChatModel(messages=iter([AIMessage(content="Middle")]))
# Prefix is outer, Suffix is inner
@@ -542,16 +540,17 @@ class TestMiddlewareComposition:
class RetryMiddleware(AgentMiddleware):
def wrap_model_call(self, request, handler):
try:
result = handler(request)
result = handler(request.model_call)
return result
except Exception:
result = handler(request)
result = handler(request.model_call)
return result
class UppercaseMiddleware(AgentMiddleware):
def wrap_model_call(self, request, handler):
result = handler(request)
return AIMessage(content=result.content.upper())
result = handler(request.model_call)
ai_msg = result.result[0]
return ModelResponse(result=[AIMessage(content=ai_msg.content.upper())])
model = FailOnceThenSucceed(messages=iter([AIMessage(content="success")]))
# Retry outer, Uppercase inner
@@ -569,21 +568,21 @@ class TestMiddlewareComposition:
class FirstMiddleware(AgentMiddleware):
def wrap_model_call(self, request, handler):
execution_order.append("first-before")
response = handler(request)
response = handler(request.model_call)
execution_order.append("first-after")
return response
class SecondMiddleware(AgentMiddleware):
def wrap_model_call(self, request, handler):
execution_order.append("second-before")
response = handler(request)
response = handler(request.model_call)
execution_order.append("second-after")
return response
class ThirdMiddleware(AgentMiddleware):
def wrap_model_call(self, request, handler):
execution_order.append("third-before")
response = handler(request)
response = handler(request.model_call)
execution_order.append("third-after")
return response
@@ -613,7 +612,7 @@ class TestMiddlewareComposition:
class OuterMiddleware(AgentMiddleware):
def wrap_model_call(self, request, handler):
execution_order.append("outer-before")
result = handler(request)
result = handler(request.model_call)
execution_order.append("outer-after")
return result
@@ -621,16 +620,16 @@ class TestMiddlewareComposition:
def wrap_model_call(self, request, handler):
execution_order.append("middle-before")
# Always retry once (call handler twice)
result = handler(request)
result = handler(request.model_call)
execution_order.append("middle-retry")
result = handler(request)
result = handler(request.model_call)
execution_order.append("middle-after")
return result
class InnerMiddleware(AgentMiddleware):
def wrap_model_call(self, request, handler):
execution_order.append("inner-before")
result = handler(request)
result = handler(request.model_call)
execution_order.append("inner-after")
return result
@@ -675,7 +674,7 @@ class TestAsyncOnModelCall:
class LoggingMiddleware(AgentMiddleware):
async def awrap_model_call(self, request, handler):
log.append("before")
result = await handler(request)
result = await handler(request.model_call)
log.append("after")
return result
@@ -702,9 +701,9 @@ class TestAsyncOnModelCall:
class RetryMiddleware(AgentMiddleware):
async def awrap_model_call(self, request, handler):
try:
return await handler(request)
return await handler(request.model_call)
except Exception:
return await handler(request)
return await handler(request.model_call)
model = AsyncFailOnceThenSucceed(messages=iter([AIMessage(content="Async success")]))
agent = create_agent(model=model, middleware=[RetryMiddleware()])
@@ -725,9 +724,8 @@ class TestEdgeCases:
class RequestModifyingMiddleware(AgentMiddleware):
def wrap_model_call(self, request, handler):
# Add a system message to the request
modified_request = request
modified_messages.append(len(modified_request.messages))
return handler(modified_request)
modified_messages.append(len(request.model_call.messages))
return handler(request.model_call)
model = GenericFakeChatModel(messages=iter([AIMessage(content="Response")]))
agent = create_agent(model=model, middleware=[RequestModifyingMiddleware()])
@@ -744,11 +742,11 @@ class TestEdgeCases:
def wrap_model_call(self, request, handler):
attempts.append("first-attempt")
try:
result = handler(request)
result = handler(request.model_call)
return result
except Exception:
attempts.append("retry-attempt")
result = handler(request)
result = handler(request.model_call)
return result
call_count = {"value": 0}

View File

@@ -8,7 +8,7 @@ from langchain.agents.middleware.context_editing import (
ClearToolUsesEdit,
ContextEditingMiddleware,
)
from langchain.agents.middleware.types import AgentState, ModelRequest
from langchain.agents.middleware.types import AgentState, ModelCall, ModelRequest, ModelResponse
from langchain_core.language_models.fake_chat_models import FakeChatModel
from langchain_core.messages import (
AIMessage,
@@ -55,16 +55,19 @@ def _make_state_and_request(
model = _TokenCountingChatModel()
conversation = list(messages)
state = cast(AgentState, {"messages": conversation})
request = ModelRequest(
model_call = ModelCall(
model=model,
system_prompt=system_prompt,
messages=conversation,
tool_choice=None,
tools=[],
response_format=None,
model_settings={},
)
request = ModelRequest(
model_call=model_call,
state=state,
runtime=_fake_runtime(),
model_settings={},
)
return state, request
@@ -82,16 +85,16 @@ def test_no_edit_when_below_trigger() -> None:
edits=[ClearToolUsesEdit(trigger=50)],
)
def mock_handler(req: ModelRequest) -> AIMessage:
return AIMessage(content="mock response")
def mock_handler(req: ModelCall) -> ModelResponse:
return ModelResponse(result=[AIMessage(content="mock response")])
# Call wrap_model_call which modifies the request
middleware.wrap_model_call(request, mock_handler)
# The request should have been modified in place
assert request.messages[0].content == ""
assert request.messages[1].content == "12345"
assert state["messages"] == request.messages
assert request.model_call.messages[0].content == ""
assert request.model_call.messages[1].content == "12345"
assert state["messages"] == request.model_call.messages
def test_clear_tool_outputs_and_inputs() -> None:
@@ -115,14 +118,14 @@ def test_clear_tool_outputs_and_inputs() -> None:
)
middleware = ContextEditingMiddleware(edits=[edit])
def mock_handler(req: ModelRequest) -> AIMessage:
return AIMessage(content="mock response")
def mock_handler(req: ModelCall) -> ModelResponse:
return ModelResponse(result=[AIMessage(content="mock response")])
# Call wrap_model_call which modifies the request
middleware.wrap_model_call(request, mock_handler)
cleared_ai = request.messages[0]
cleared_tool = request.messages[1]
cleared_ai = request.model_call.messages[0]
cleared_tool = request.model_call.messages[1]
assert isinstance(cleared_tool, ToolMessage)
assert cleared_tool.content == "[cleared output]"
@@ -134,7 +137,7 @@ def test_clear_tool_outputs_and_inputs() -> None:
assert context_meta is not None
assert context_meta["cleared_tool_inputs"] == [tool_call_id]
assert state["messages"] == request.messages
assert state["messages"] == request.model_call.messages
def test_respects_keep_last_tool_results() -> None:
@@ -167,21 +170,21 @@ def test_respects_keep_last_tool_results() -> None:
token_count_method="model",
)
def mock_handler(req: ModelRequest) -> AIMessage:
return AIMessage(content="mock response")
def mock_handler(req: ModelCall) -> ModelResponse:
return ModelResponse(result=[AIMessage(content="mock response")])
# Call wrap_model_call which modifies the request
middleware.wrap_model_call(request, mock_handler)
cleared_messages = [
msg
for msg in request.messages
for msg in request.model_call.messages
if isinstance(msg, ToolMessage) and msg.content == "[cleared]"
]
assert len(cleared_messages) == 2
assert isinstance(request.messages[-1], ToolMessage)
assert request.messages[-1].content != "[cleared]"
assert isinstance(request.model_call.messages[-1], ToolMessage)
assert request.model_call.messages[-1].content != "[cleared]"
def test_exclude_tools_prevents_clearing() -> None:
@@ -215,14 +218,14 @@ def test_exclude_tools_prevents_clearing() -> None:
],
)
def mock_handler(req: ModelRequest) -> AIMessage:
return AIMessage(content="mock response")
def mock_handler(req: ModelCall) -> ModelResponse:
return ModelResponse(result=[AIMessage(content="mock response")])
# Call wrap_model_call which modifies the request
middleware.wrap_model_call(request, mock_handler)
search_tool = request.messages[1]
calc_tool = request.messages[3]
search_tool = request.model_call.messages[1]
calc_tool = request.model_call.messages[3]
assert isinstance(search_tool, ToolMessage)
assert search_tool.content == "search-results" * 20
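
The fixtures above split what used to be a single ModelRequest into a ModelCall (model, prompt, messages, tools, settings) wrapped by a ModelRequest carrying state and runtime. A condensed sketch of that shape, assuming the constructor fields shown in this diff (the None model mirrors the test helpers and is not meant to be invoked):

from typing import cast
from langchain_core.messages import AIMessage, HumanMessage
from langchain.agents.middleware.types import ModelCall, ModelRequest, ModelResponse
from langgraph.runtime import Runtime

messages = [HumanMessage("Hello")]
model_call = ModelCall(
    model=None,  # a real chat model in the tests; None mirrors the helper defaults
    system_prompt=None,
    messages=messages,
    tool_choice=None,
    tools=[],
    response_format=None,
    model_settings={},
)
request = ModelRequest(
    model_call=model_call,
    state={"messages": messages},
    runtime=cast(Runtime, object()),
)

def mock_handler(call: ModelCall) -> ModelResponse:
    return ModelResponse(result=[AIMessage(content="mock response")])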

View File

@@ -4,7 +4,7 @@ import pytest
from langchain_core.messages import AIMessage
from langchain.agents.factory import _chain_model_call_handlers
from langchain.agents.middleware.types import ModelRequest
from langchain.agents.middleware.types import ModelCall, ModelRequest, ModelResponse
from typing import cast
from langgraph.runtime import Runtime
@@ -13,18 +13,34 @@ from langgraph.runtime import Runtime
def create_test_request(**kwargs):
"""Helper to create a ModelRequest with sensible defaults."""
defaults = {
model_call_defaults = {
"messages": [],
"model": None,
"system_prompt": None,
"tool_choice": None,
"tools": [],
"response_format": None,
"model_settings": {},
}
# Extract model_call fields from kwargs
model_call_kwargs = {
k: kwargs.pop(k, v)
for k, v in model_call_defaults.items()
if k in kwargs or k in model_call_defaults
}
# Create ModelCall
model_call = ModelCall(**model_call_kwargs)
# Create ModelRequest with remaining kwargs
request_defaults = {
"model_call": model_call,
"state": {},
"runtime": cast(Runtime, object()),
}
defaults.update(kwargs)
return ModelRequest(**defaults)
request_defaults.update(kwargs)
return ModelRequest(**request_defaults)
class TestChainModelCallHandlers:
@@ -65,7 +81,7 @@ class TestChainModelCallHandlers:
# Execute the composed handler
def mock_base_handler(req):
return AIMessage(content="test")
return ModelResponse(result=[AIMessage(content="test")])
result = composed(create_test_request(), mock_base_handler)
@@ -75,7 +91,7 @@ class TestChainModelCallHandlers:
"inner-after",
"outer-after",
]
assert result.content == "test"
assert result.result[0].content == "test"
def test_three_handlers_composition(self) -> None:
"""Test composition of three handlers."""
@@ -103,7 +119,7 @@ class TestChainModelCallHandlers:
assert composed is not None
def mock_base_handler(req):
return AIMessage(content="test")
return ModelResponse(result=[AIMessage(content="test")])
result = composed(create_test_request(), mock_base_handler)
@@ -116,7 +132,7 @@ class TestChainModelCallHandlers:
"second-after",
"first-after",
]
assert result.content == "test"
assert result.result[0].content == "test"
def test_inner_handler_retry(self) -> None:
"""Test inner handler retrying before outer sees response."""
@@ -144,12 +160,12 @@ class TestChainModelCallHandlers:
call_count["value"] += 1
if call_count["value"] < 3:
raise ValueError("fail")
return AIMessage(content="success")
return ModelResponse(result=[AIMessage(content="success")])
result = composed(create_test_request(), mock_base_handler)
assert inner_attempts == [0, 1, 2]
assert result.content == "success"
assert result.result[0].content == "success"
def test_error_to_success_conversion(self) -> None:
"""Test handler converting error to success response."""
@@ -158,7 +174,7 @@ class TestChainModelCallHandlers:
try:
return handler(request)
except Exception:
return AIMessage(content="Fallback response")
return ModelResponse(result=[AIMessage(content="Fallback response")])
def inner_passthrough(request, handler):
return handler(request)
@@ -171,32 +187,32 @@ class TestChainModelCallHandlers:
result = composed(create_test_request(), mock_base_handler)
assert result.content == "Fallback response"
assert result.result[0].content == "Fallback response"
def test_request_modification(self) -> None:
"""Test handlers modifying the request."""
requests_seen = []
def outer_add_context(request, handler):
modified_request = create_test_request(
messages=[*request.messages], system_prompt="Added by outer"
)
return handler(modified_request)
# Modify the model_call
request.model_call.system_prompt = "Added by outer"
return handler(request.model_call)
def inner_track_request(request, handler):
requests_seen.append(request.system_prompt)
return handler(request)
# Inner handler receives ModelRequest due to composition
requests_seen.append(request.model_call.system_prompt)
return handler(request.model_call)
composed = _chain_model_call_handlers([outer_add_context, inner_track_request])
assert composed is not None
def mock_base_handler(req):
return AIMessage(content="response")
return ModelResponse(result=[AIMessage(content="response")])
result = composed(create_test_request(), mock_base_handler)
assert requests_seen == ["Added by outer"]
assert result.content == "response"
assert result.result[0].content == "response"
def test_composition_preserves_state_and_runtime(self) -> None:
"""Test that state and runtime are passed through composition."""
@@ -220,7 +236,7 @@ class TestChainModelCallHandlers:
test_runtime = {"test": "runtime"}
def mock_base_handler(req):
return AIMessage(content="test")
return ModelResponse(result=[AIMessage(content="test")])
# Create request with state and runtime
test_request = create_test_request()
@@ -231,7 +247,7 @@ class TestChainModelCallHandlers:
# Both handlers should see same state and runtime
assert state_values == [("outer", test_state), ("inner", test_state)]
assert runtime_values == [("outer", test_runtime), ("inner", test_runtime)]
assert result.content == "test"
assert result.result[0].content == "test"
def test_multiple_yields_in_retry_loop(self) -> None:
"""Test handler that retries multiple times."""
@@ -257,11 +273,11 @@ class TestChainModelCallHandlers:
attempt["value"] += 1
if attempt["value"] == 1:
raise ValueError("fail")
return AIMessage(content="ok")
return ModelResponse(result=[AIMessage(content="ok")])
result = composed(create_test_request(), mock_base_handler)
# Outer called once, inner retried so base handler called twice
assert call_count["value"] == 1
assert attempt["value"] == 2
assert result.content == "ok"
assert result.result[0].content == "ok"

View File

@@ -50,7 +50,9 @@ from langchain.agents.middleware.types import (
AgentMiddleware,
AgentState,
hook_config,
ModelCall,
ModelRequest,
ModelResponse,
OmitFromInput,
OmitFromOutput,
PrivateStateAttr,
@@ -118,9 +120,9 @@ def test_create_agent_diagram(
def wrap_model_call(
self,
request: ModelRequest,
handler: Callable[[ModelRequest], AIMessage],
) -> AIMessage:
return handler(request)
handler: Callable[[ModelCall], ModelResponse],
) -> ModelResponse:
return handler(request.model_call)
def after_model(self, state, runtime):
pass
@@ -132,9 +134,9 @@ def test_create_agent_diagram(
def wrap_model_call(
self,
request: ModelRequest,
handler: Callable[[ModelRequest], AIMessage],
) -> AIMessage:
return handler(request)
handler: Callable[[ModelCall], ModelResponse],
) -> ModelResponse:
return handler(request.model_call)
def after_model(self, state, runtime):
pass
@@ -260,10 +262,10 @@ def test_create_agent_invoke(
def wrap_model_call(
self,
request: ModelRequest,
handler: Callable[[ModelRequest], AIMessage],
) -> AIMessage:
handler: Callable[[ModelCall], ModelResponse],
) -> ModelResponse:
calls.append("NoopSeven.wrap_model_call")
return handler(request)
return handler(request.model_call)
def after_model(self, state, runtime):
calls.append("NoopSeven.after_model")
@@ -275,10 +277,10 @@ def test_create_agent_invoke(
def wrap_model_call(
self,
request: ModelRequest,
handler: Callable[[ModelRequest], AIMessage],
) -> AIMessage:
handler: Callable[[ModelCall], ModelResponse],
) -> ModelResponse:
calls.append("NoopEight.wrap_model_call")
return handler(request)
return handler(request.model_call)
def after_model(self, state, runtime):
calls.append("NoopEight.after_model")
@@ -361,10 +363,10 @@ def test_create_agent_jump(
def wrap_model_call(
self,
request: ModelRequest,
handler: Callable[[ModelRequest], AIMessage],
) -> AIMessage:
handler: Callable[[ModelCall], ModelResponse],
) -> ModelResponse:
calls.append("NoopSeven.wrap_model_call")
return handler(request)
return handler(request.model_call)
def after_model(self, state, runtime):
calls.append("NoopSeven.after_model")
@@ -378,10 +380,10 @@ def test_create_agent_jump(
def wrap_model_call(
self,
request: ModelRequest,
handler: Callable[[ModelRequest], AIMessage],
) -> AIMessage:
handler: Callable[[ModelCall], ModelResponse],
) -> ModelResponse:
calls.append("NoopEight.wrap_model_call")
return handler(request)
return handler(request.model_call)
def after_model(self, state, runtime):
calls.append("NoopEight.after_model")
@@ -1032,46 +1034,54 @@ def test_anthropic_prompt_caching_middleware_initialization() -> None:
assert middleware.ttl == "5m"
assert middleware.min_messages_to_cache == 0
fake_request = ModelRequest(
model_call = ModelCall(
model=FakeToolCallingModel(),
messages=[HumanMessage("Hello")],
system_prompt=None,
tool_choice=None,
tools=[],
response_format=None,
state={"messages": [HumanMessage("Hello")]},
runtime=cast(Runtime, object()),
model_settings={},
)
fake_request = ModelRequest(
model_call=model_call,
state={"messages": [HumanMessage("Hello")]},
runtime=cast(Runtime, object()),
)
def mock_handler(req: ModelRequest) -> AIMessage:
return AIMessage(content="mock response", **req.model_settings)
def mock_handler(req: ModelCall) -> ModelResponse:
return ModelResponse(result=[AIMessage(content="mock response", **req.model_settings)])
result = middleware.wrap_model_call(fake_request, mock_handler)
# Check that model_settings were passed through via the request
assert fake_request.model_settings == {"cache_control": {"type": "ephemeral", "ttl": "5m"}}
assert fake_request.model_call.model_settings == {
"cache_control": {"type": "ephemeral", "ttl": "5m"}
}
def test_anthropic_prompt_caching_middleware_unsupported_model() -> None:
"""Test AnthropicPromptCachingMiddleware with unsupported model."""
from typing import cast
fake_request = ModelRequest(
model_call = ModelCall(
model=FakeToolCallingModel(),
messages=[HumanMessage("Hello")],
system_prompt=None,
tool_choice=None,
tools=[],
response_format=None,
model_settings={},
)
fake_request = ModelRequest(
model_call=model_call,
state={"messages": [HumanMessage("Hello")]},
runtime=cast(Runtime, object()),
model_settings={},
)
middleware = AnthropicPromptCachingMiddleware(unsupported_model_behavior="raise")
def mock_handler(req: ModelRequest) -> AIMessage:
return AIMessage(content="mock response")
def mock_handler(req: ModelCall) -> ModelResponse:
return ModelResponse(result=[AIMessage(content="mock response")])
with pytest.raises(
ValueError,
@@ -1102,12 +1112,12 @@ def test_anthropic_prompt_caching_middleware_unsupported_model() -> None:
"AnthropicPromptCachingMiddleware caching middleware only supports Anthropic models. Please install langchain-anthropic."
in str(w[-1].message)
)
assert isinstance(result, AIMessage)
assert isinstance(result, ModelResponse)
with warnings.catch_warnings(record=True) as w:
with patch.dict("sys.modules", {"langchain_anthropic": langchain_anthropic}):
result = middleware.wrap_model_call(fake_request, mock_handler)
assert isinstance(result, AIMessage)
assert isinstance(result, ModelResponse)
assert len(w) == 1
assert (
"AnthropicPromptCachingMiddleware caching middleware only supports Anthropic models, not instances of"
@@ -1117,11 +1127,11 @@ def test_anthropic_prompt_caching_middleware_unsupported_model() -> None:
middleware = AnthropicPromptCachingMiddleware(unsupported_model_behavior="ignore")
result = middleware.wrap_model_call(fake_request, mock_handler)
assert isinstance(result, AIMessage)
assert isinstance(result, ModelResponse)
with patch.dict("sys.modules", {"langchain_anthropic": {"ChatAnthropic": object()}}):
result = middleware.wrap_model_call(fake_request, mock_handler)
assert isinstance(result, AIMessage)
assert isinstance(result, ModelResponse)
# Tests for SummarizationMiddleware
@@ -1346,10 +1356,10 @@ def test_on_model_call() -> None:
def wrap_model_call(
self,
request: ModelRequest,
handler: Callable[[ModelRequest], AIMessage],
) -> AIMessage:
request.messages.append(HumanMessage("remember to be nice!"))
return handler(request)
handler: Callable[[ModelCall], ModelResponse],
) -> ModelResponse:
request.model_call.messages.append(HumanMessage("remember to be nice!"))
return handler(request.model_call)
agent = create_agent(
model=FakeToolCallingModel(),
@@ -1482,10 +1492,10 @@ def test_runtime_injected_into_middleware() -> None:
def wrap_model_call(
self,
request: ModelRequest,
handler: Callable[[ModelRequest], AIMessage],
) -> AIMessage:
handler: Callable[[ModelCall], ModelResponse],
) -> ModelResponse:
assert request.runtime is not None
return handler(request)
return handler(request.model_call)
def after_model(self, state: AgentState, runtime: Runtime) -> None:
assert runtime is not None
@@ -1581,25 +1591,28 @@ def test_planning_middleware_on_model_call(original_prompt, expected_prompt_pref
state: PlanningState = {"messages": [HumanMessage(content="Hello")]}
request = ModelRequest(
model_call = ModelCall(
model=model,
system_prompt=original_prompt,
messages=[HumanMessage(content="Hello")],
tool_choice=None,
tools=[],
response_format=None,
state=state,
runtime=cast(Runtime, object()),
model_settings={},
)
request = ModelRequest(
model_call=model_call,
state=state,
runtime=cast(Runtime, object()),
)
def mock_handler(req: ModelRequest) -> AIMessage:
return AIMessage(content="mock response")
def mock_handler(req: ModelCall) -> ModelResponse:
return ModelResponse(result=[AIMessage(content="mock response")])
# Call wrap_model_call to trigger the middleware logic
middleware.wrap_model_call(request, mock_handler)
# Check that the request was modified in place
assert request.system_prompt.startswith(expected_prompt_prefix)
assert request.model_call.system_prompt.startswith(expected_prompt_prefix)
@pytest.mark.parametrize(
@@ -1725,13 +1738,13 @@ def test_planning_middleware_custom_system_prompt() -> None:
state: PlanningState = {"messages": [HumanMessage(content="Hello")]}
def mock_handler(req: ModelRequest) -> AIMessage:
return AIMessage(content="mock response")
def mock_handler(req: ModelCall) -> ModelResponse:
return ModelResponse(result=[AIMessage(content="mock response")])
# Call wrap_model_call to trigger the middleware logic
middleware.wrap_model_call(request, mock_handler)
# Check that the request was modified in place
assert request.system_prompt == f"Original prompt\n\n{custom_system_prompt}"
assert request.model_call.system_prompt == f"Original prompt\n\n{custom_system_prompt}"
def test_planning_middleware_custom_tool_description() -> None:
@@ -1757,25 +1770,28 @@ def test_planning_middleware_custom_system_prompt_and_tool_description() -> None
model = FakeToolCallingModel()
state: PlanningState = {"messages": [HumanMessage(content="Hello")]}
request = ModelRequest(
model_call = ModelCall(
model=model,
system_prompt=None,
messages=[HumanMessage(content="Hello")],
tool_choice=None,
tools=[],
response_format=None,
state=state,
runtime=cast(Runtime, object()),
model_settings={},
)
request = ModelRequest(
model_call=model_call,
state=state,
runtime=cast(Runtime, object()),
)
def mock_handler(req: ModelRequest) -> AIMessage:
return AIMessage(content="mock response")
def mock_handler(req: ModelCall) -> ModelResponse:
return ModelResponse(result=[AIMessage(content="mock response")])
# Call wrap_model_call to trigger the middleware logic
middleware.wrap_model_call(request, mock_handler)
# Check that the request was modified in place
assert request.system_prompt == custom_system_prompt
assert request.model_call.system_prompt == custom_system_prompt
# Verify tool description
assert len(middleware.tools) == 1
@@ -2047,11 +2063,11 @@ async def test_create_agent_async_invoke() -> None:
async def awrap_model_call(
self,
request: ModelRequest,
handler: Callable[[ModelRequest], Awaitable[AIMessage]],
) -> AIMessage:
handler: Callable[[ModelCall], Awaitable[ModelResponse]],
) -> ModelResponse:
calls.append("AsyncMiddleware.awrap_model_call")
request.messages.append(HumanMessage("async middleware message"))
return await handler(request)
request.model_call.messages.append(HumanMessage("async middleware message"))
return await handler(request.model_call)
async def aafter_model(self, state, runtime) -> None:
calls.append("AsyncMiddleware.aafter_model")
@@ -2108,10 +2124,10 @@ async def test_create_agent_async_invoke_multiple_middleware() -> None:
async def awrap_model_call(
self,
request: ModelRequest,
handler: Callable[[ModelRequest], Awaitable[AIMessage]],
) -> AIMessage:
handler: Callable[[ModelCall], Awaitable[ModelResponse]],
) -> ModelResponse:
calls.append("AsyncMiddlewareOne.awrap_model_call")
return await handler(request)
return await handler(request.model_call)
async def aafter_model(self, state, runtime) -> None:
calls.append("AsyncMiddlewareOne.aafter_model")
@@ -2123,10 +2139,10 @@ async def test_create_agent_async_invoke_multiple_middleware() -> None:
async def awrap_model_call(
self,
request: ModelRequest,
handler: Callable[[ModelRequest], Awaitable[AIMessage]],
) -> AIMessage:
handler: Callable[[ModelCall], Awaitable[ModelResponse]],
) -> ModelResponse:
calls.append("AsyncMiddlewareTwo.awrap_model_call")
return await handler(request)
return await handler(request.model_call)
async def aafter_model(self, state, runtime) -> None:
calls.append("AsyncMiddlewareTwo.aafter_model")
@@ -2196,10 +2212,10 @@ async def test_create_agent_mixed_sync_async_middleware() -> None:
def wrap_model_call(
self,
request: ModelRequest,
handler: Callable[[ModelRequest], AIMessage],
) -> AIMessage:
handler: Callable[[ModelCall], ModelResponse],
) -> ModelResponse:
calls.append("SyncMiddleware.wrap_model_call")
return handler(request)
return handler(request.model_call)
def after_model(self, state, runtime) -> None:
calls.append("SyncMiddleware.after_model")
@@ -2211,10 +2227,10 @@ async def test_create_agent_mixed_sync_async_middleware() -> None:
async def awrap_model_call(
self,
request: ModelRequest,
handler: Callable[[ModelRequest], Awaitable[AIMessage]],
) -> AIMessage:
handler: Callable[[ModelCall], Awaitable[ModelResponse]],
) -> ModelResponse:
calls.append("AsyncMiddleware.awrap_model_call")
return await handler(request)
return await handler(request.model_call)
async def aafter_model(self, state, runtime) -> None:
calls.append("AsyncMiddleware.aafter_model")
@@ -2267,11 +2283,11 @@ def test_wrap_model_call_hook() -> None:
def wrap_model_call(self, request, handler):
try:
return handler(request)
return handler(request.model_call)
except Exception:
# Retry on error
self.retry_count += 1
return handler(request)
return handler(request.model_call)
failing_model = FailingModel()
retry_middleware = RetryMiddleware()
@@ -2311,7 +2327,7 @@ def test_wrap_model_call_retry_count() -> None:
for attempt in range(max_retries):
self.attempts.append(attempt + 1)
try:
return handler(request)
return handler(request.model_call)
except Exception as e:
last_exception = e
if attempt < max_retries - 1:
@@ -2348,7 +2364,7 @@ def test_wrap_model_call_no_retry() -> None:
class NoRetryMiddleware(AgentMiddleware):
def wrap_model_call(self, request, handler):
return handler(request)
return handler(request.model_call)
agent = create_agent(model=FailingModel(), middleware=[NoRetryMiddleware()])
@@ -2469,7 +2485,7 @@ def test_wrap_model_call_max_attempts() -> None:
for attempt in range(self.max_retries):
self.attempt_count += 1
try:
return handler(request)
return handler(request.model_call)
except Exception as e:
last_exception = e
# Continue to retry
@@ -2520,11 +2536,11 @@ async def test_wrap_model_call_async() -> None:
async def awrap_model_call(self, request, handler):
try:
return await handler(request)
return await handler(request.model_call)
except Exception:
# Retry on error
self.retry_count += 1
return await handler(request)
return await handler(request.model_call)
failing_model = AsyncFailingModel()
retry_middleware = AsyncRetryMiddleware()
@@ -2559,10 +2575,12 @@ def test_wrap_model_call_rewrite_response() -> None:
"""Middleware that rewrites the response."""
def wrap_model_call(self, request, handler):
result = handler(request)
result = handler(request.model_call)
# Rewrite the response
return AIMessage(content=f"REWRITTEN: {result.content}")
return ModelResponse(
result=[AIMessage(content=f"REWRITTEN: {result.result[0].content}")]
)
model = SimpleModel()
middleware = ResponseRewriteMiddleware()
@@ -2593,10 +2611,12 @@ def test_wrap_model_call_convert_error_to_response() -> None:
def wrap_model_call(self, request, handler):
try:
return handler(request)
return handler(request.model_call)
except Exception as e:
# Convert error to success response
return AIMessage(content=f"Error occurred: {e}. Using fallback response.")
return ModelResponse(
result=[AIMessage(content=f"Error occurred: {e}. Using fallback response.")]
)
model = AlwaysFailingModel()
middleware = ErrorToResponseMiddleware()
@@ -2619,7 +2639,7 @@ def test_create_agent_sync_invoke_with_only_async_middleware_raises_error() -> N
self,
request: ModelRequest,
handler: Callable[[ModelRequest], Awaitable[AIMessage]],
) -> AIMessage:
) -> ModelResponse:
return await handler(request)
agent = create_agent(
@@ -2649,16 +2669,16 @@ def test_create_agent_sync_invoke_with_mixed_middleware() -> None:
def wrap_model_call(
self,
request: ModelRequest,
handler: Callable[[ModelRequest], AIMessage],
) -> AIMessage:
handler: Callable[[ModelCall], ModelResponse],
) -> ModelResponse:
calls.append("MixedMiddleware.wrap_model_call")
return handler(request)
return handler(request.model_call)
async def awrap_model_call(
self,
request: ModelRequest,
handler: Callable[[ModelRequest], Awaitable[AIMessage]],
) -> AIMessage:
) -> ModelResponse:
calls.append("MixedMiddleware.awrap_model_call")
return await handler(request)
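
Taken together, the updated annotations above amount to the following middleware skeleton (a sketch of the signatures this diff settles on, with an illustrative class name):

from collections.abc import Awaitable, Callable
from langchain.agents.middleware.types import (
    AgentMiddleware,
    ModelCall,
    ModelRequest,
    ModelResponse,
)

class PassthroughMiddleware(AgentMiddleware):
    def wrap_model_call(
        self,
        request: ModelRequest,
        handler: Callable[[ModelCall], ModelResponse],
    ) -> ModelResponse:
        return handler(request.model_call)

    async def awrap_model_call(
        self,
        request: ModelRequest,
        handler: Callable[[ModelCall], Awaitable[ModelResponse]],
    ) -> ModelResponse:
        return await handler(request.model_call)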

View File

@@ -13,7 +13,9 @@ from langgraph.types import Command
from langchain.agents.middleware.types import (
AgentMiddleware,
AgentState,
ModelCall,
ModelRequest,
ModelResponse,
before_model,
after_model,
dynamic_prompt,
@@ -89,8 +91,8 @@ def test_on_model_call_decorator() -> None:
@wrap_model_call(state_schema=CustomState, tools=[test_tool], name="CustomOnModelCall")
def custom_on_model_call(request, handler):
request.system_prompt = "Modified"
return handler(request)
request.model_call.system_prompt = "Modified"
return handler(request.model_call)
# Verify all options were applied
assert isinstance(custom_on_model_call, AgentMiddleware)
@@ -99,22 +101,27 @@ def test_on_model_call_decorator() -> None:
assert custom_on_model_call.__class__.__name__ == "CustomOnModelCall"
# Verify it works
original_request = ModelRequest(
model_call = ModelCall(
model="test-model",
system_prompt="Original",
messages=[HumanMessage("Hello")],
tool_choice=None,
tools=[],
response_format=None,
)
original_request = ModelRequest(
model_call=model_call,
state={"messages": [HumanMessage("Hello")]},
runtime=None,
)
def mock_handler(req):
return AIMessage(content=f"Handled with prompt: {req.system_prompt}")
return ModelResponse(
result=[AIMessage(content=f"Handled with prompt: {req.system_prompt}")]
)
result = custom_on_model_call.wrap_model_call(original_request, mock_handler)
assert result.content == "Handled with prompt: Modified"
assert result.result[0].content == "Handled with prompt: Modified"
def test_all_decorators_integration() -> None:
@@ -129,7 +136,7 @@ def test_all_decorators_integration() -> None:
@wrap_model_call
def track_on_call(request, handler):
call_order.append("on_call")
return handler(request)
return handler(request.model_call)
@after_model
def track_after(state: AgentState, runtime: Runtime) -> None:
@@ -324,7 +331,7 @@ async def test_async_decorators_integration() -> None:
@wrap_model_call
async def track_async_on_call(request, handler):
call_order.append("async_on_call")
return await handler(request)
return await handler(request.model_call)
@after_model
async def track_async_after(state: AgentState, runtime: Runtime) -> None:
@@ -364,7 +371,7 @@ async def test_mixed_sync_async_decorators_integration() -> None:
@wrap_model_call
async def track_async_on_call(request, handler):
call_order.append("async_on_call")
return await handler(request)
return await handler(request.model_call)
@after_model
async def track_async_after(state: AgentState, runtime: Runtime) -> None:
@@ -581,22 +588,25 @@ def test_dynamic_prompt_decorator() -> None:
assert my_prompt.__class__.__name__ == "my_prompt"
# Verify it modifies the request correctly
original_request = ModelRequest(
model_call = ModelCall(
model="test-model",
system_prompt="Original",
messages=[HumanMessage("Hello")],
tool_choice=None,
tools=[],
response_format=None,
)
original_request = ModelRequest(
model_call=model_call,
state={"messages": [HumanMessage("Hello")]},
runtime=None,
)
def mock_handler(req):
return AIMessage(content=req.system_prompt)
return ModelResponse(result=[AIMessage(content=req.system_prompt)])
result = my_prompt.wrap_model_call(original_request, mock_handler)
assert result.content == "Dynamic test prompt"
assert result.result[0].content == "Dynamic test prompt"
def test_dynamic_prompt_uses_state() -> None:
@@ -608,22 +618,25 @@ def test_dynamic_prompt_uses_state() -> None:
return f"Prompt with {msg_count} messages"
# Verify it uses state correctly
original_request = ModelRequest(
model_call = ModelCall(
model="test-model",
system_prompt="Original",
messages=[HumanMessage("Hello")],
tool_choice=None,
tools=[],
response_format=None,
)
original_request = ModelRequest(
model_call=model_call,
state={"messages": [HumanMessage("Hello"), HumanMessage("World")]},
runtime=None,
)
def mock_handler(req):
return AIMessage(content=req.system_prompt)
return ModelResponse(result=[AIMessage(content=req.system_prompt)])
result = custom_prompt.wrap_model_call(original_request, mock_handler)
assert result.content == "Prompt with 2 messages"
assert result.result[0].content == "Prompt with 2 messages"
def test_dynamic_prompt_integration() -> None:
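
The decorator form follows the same contract; a sketch, assuming wrap_model_call is exported alongside the other decorators imported above (the prompt text and function name are illustrative):

from langchain.agents.middleware.types import wrap_model_call

@wrap_model_call
def add_reminder(request, handler):
    # mutate the ModelCall in place, then delegate to the handler
    request.model_call.system_prompt = "Be concise."
    return handler(request.model_call)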

View File

@@ -3,7 +3,13 @@
import pytest
from collections.abc import Callable
from langchain.agents.middleware.types import AgentMiddleware, AgentState, ModelRequest
from langchain.agents.middleware.types import (
AgentMiddleware,
AgentState,
ModelCall,
ModelRequest,
ModelResponse,
)
from langchain.agents.factory import create_agent
from langchain.tools import ToolNode
from langchain_core.messages import AIMessage, HumanMessage, ToolMessage
@@ -30,10 +36,10 @@ def test_model_request_tools_are_base_tools() -> None:
def wrap_model_call(
self,
request: ModelRequest,
handler: Callable[[ModelRequest], AIMessage],
) -> AIMessage:
handler: Callable[[ModelCall], ModelResponse],
) -> ModelResponse:
captured_requests.append(request)
return handler(request)
return handler(request.model_call)
agent = create_agent(
model=FakeToolCallingModel(),
@@ -49,9 +55,9 @@ def test_model_request_tools_are_base_tools() -> None:
# Check that tools in the request are BaseTool objects
request = captured_requests[0]
assert isinstance(request.tools, list)
assert len(request.tools) == 2
assert {t.name for t in request.tools} == {"search_tool", "calculator"}
assert isinstance(request.model_call.tools, list)
assert len(request.model_call.tools) == 2
assert {t.name for t in request.model_call.tools} == {"search_tool", "calculator"}
def test_middleware_can_modify_tools() -> None:
@@ -76,11 +82,13 @@ def test_middleware_can_modify_tools() -> None:
def wrap_model_call(
self,
request: ModelRequest,
handler: Callable[[ModelRequest], AIMessage],
) -> AIMessage:
handler: Callable[[ModelCall], ModelResponse],
) -> ModelResponse:
# Only allow tool_a and tool_b
request.tools = [t for t in request.tools if t.name in ["tool_a", "tool_b"]]
return handler(request)
request.model_call.tools = [
t for t in request.model_call.tools if t.name in ["tool_a", "tool_b"]
]
return handler(request.model_call)
# Model will try to call tool_a
model = FakeToolCallingModel(
@@ -121,11 +129,11 @@ def test_unknown_tool_raises_error() -> None:
def wrap_model_call(
self,
request: ModelRequest,
handler: Callable[[ModelRequest], AIMessage],
) -> AIMessage:
handler: Callable[[ModelCall], ModelResponse],
) -> ModelResponse:
# Add an unknown tool
request.tools = request.tools + [unknown_tool]
return handler(request)
request.model_call.tools = request.model_call.tools + [unknown_tool]
return handler(request.model_call)
agent = create_agent(
model=FakeToolCallingModel(),
@@ -160,12 +168,14 @@ def test_middleware_can_add_and_remove_tools() -> None:
def wrap_model_call(
self,
request: ModelRequest,
handler: Callable[[ModelRequest], AIMessage],
) -> AIMessage:
handler: Callable[[ModelCall], ModelResponse],
) -> ModelResponse:
# Remove admin_tool if not admin
if not request.state.get("is_admin", False):
request.tools = [t for t in request.tools if t.name != "admin_tool"]
return handler(request)
request.model_call.tools = [
t for t in request.model_call.tools if t.name != "admin_tool"
]
return handler(request.model_call)
model = FakeToolCallingModel()
@@ -198,11 +208,11 @@ def test_empty_tools_list_is_valid() -> None:
def wrap_model_call(
self,
request: ModelRequest,
handler: Callable[[ModelRequest], AIMessage],
) -> AIMessage:
handler: Callable[[ModelCall], ModelResponse],
) -> ModelResponse:
# Remove all tools
request.tools = []
return handler(request)
request.model_call.tools = []
return handler(request.model_call)
model = FakeToolCallingModel()
@@ -241,25 +251,25 @@ def test_tools_preserved_across_multiple_middleware() -> None:
def wrap_model_call(
self,
request: ModelRequest,
handler: Callable[[ModelRequest], AIMessage],
) -> AIMessage:
modification_order.append([t.name for t in request.tools])
handler: Callable[[ModelCall], ModelResponse],
) -> ModelResponse:
modification_order.append([t.name for t in request.model_call.tools])
# Remove tool_c
request.tools = [t for t in request.tools if t.name != "tool_c"]
return handler(request)
request.model_call.tools = [t for t in request.model_call.tools if t.name != "tool_c"]
return handler(request.model_call)
class SecondMiddleware(AgentMiddleware):
def wrap_model_call(
self,
request: ModelRequest,
handler: Callable[[ModelRequest], AIMessage],
) -> AIMessage:
modification_order.append([t.name for t in request.tools])
handler: Callable[[ModelCall], ModelResponse],
) -> ModelResponse:
modification_order.append([t.name for t in request.model_call.tools])
# Should not see tool_c here
assert all(t.name != "tool_c" for t in request.tools)
assert all(t.name != "tool_c" for t in request.model_call.tools)
# Remove tool_b
request.tools = [t for t in request.tools if t.name != "tool_b"]
return handler(request)
request.model_call.tools = [t for t in request.model_call.tools if t.name != "tool_b"]
return handler(request.model_call)
agent = create_agent(
model=FakeToolCallingModel(),
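
Tool selection follows the same pattern as the tests above: tools now live on the ModelCall, so middleware filters request.model_call.tools before delegating. A sketch with an illustrative class and tool name:

from langchain.agents.middleware.types import AgentMiddleware

class DropAdminToolMiddleware(AgentMiddleware):
    def wrap_model_call(self, request, handler):
        # remove privileged tools for non-admin callers, then delegate
        if not request.state.get("is_admin", False):
            request.model_call.tools = [
                t for t in request.model_call.tools if t.name != "admin_tool"
            ]
        return handler(request.model_call)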

View File

@@ -698,7 +698,12 @@ class TestDynamicModelWithResponseFormat:
selected based on the final model's capabilities.
"""
from unittest.mock import patch
from langchain.agents.middleware.types import AgentMiddleware, ModelRequest
from langchain.agents.middleware.types import (
AgentMiddleware,
ModelCall,
ModelRequest,
ModelResponse,
)
from langchain_core.language_models.fake_chat_models import GenericFakeChatModel
# Custom model that we'll use to test whether the tool strategy is applied
@@ -730,11 +735,11 @@ class TestDynamicModelWithResponseFormat:
def wrap_model_call(
self,
request: ModelRequest,
handler: Callable[[ModelRequest], CoreAIMessage],
) -> CoreAIMessage:
handler: Callable[[ModelCall], ModelResponse],
) -> ModelResponse:
# Replace the model with our custom test model
request.model = model
return handler(request)
request.model_call.model = model
return handler(request.model_call)
# Track which model is checked for provider strategy support
calls = []