diff --git a/libs/langchain_v1/langchain/agents/middleware_agent.py b/libs/langchain_v1/langchain/agents/middleware_agent.py
index bf46b9cc8e5..53f94be3d60 100644
--- a/libs/langchain_v1/langchain/agents/middleware_agent.py
+++ b/libs/langchain_v1/langchain/agents/middleware_agent.py
@@ -5,7 +5,7 @@ from collections.abc import Callable, Sequence
 from typing import Any
 
 from langchain_core.language_models.chat_models import BaseChatModel
-from langchain_core.messages import AIMessage, AnyMessage, SystemMessage, ToolMessage
+from langchain_core.messages import AIMessage, SystemMessage, ToolMessage
 from langchain_core.runnables import Runnable
 from langchain_core.tools import BaseTool
 from langgraph.constants import END, START
@@ -211,24 +211,6 @@ def create_agent(  # noqa: PLR0915
         context_schema=context_schema,
     )
 
-    def _prepare_model_request(state: dict[str, Any]) -> tuple[ModelRequest, list[AnyMessage]]:
-        """Prepare model request and messages."""
-        request = state.get("model_request") or ModelRequest(
-            model=model,
-            tools=default_tools,
-            system_prompt=system_prompt,
-            response_format=response_format,
-            messages=state["messages"],
-            tool_choice=None,
-        )
-
-        # prepare messages
-        messages = request.messages
-        if request.system_prompt:
-            messages = [SystemMessage(request.system_prompt), *messages]
-
-        return request, messages
-
     def _handle_model_output(state: dict[str, Any], output: AIMessage) -> dict[str, Any]:
         """Handle model output including structured responses."""
         # Handle structured output with native strategy
@@ -342,8 +324,14 @@ def create_agent(  # noqa: PLR0915
 
     def model_request(state: dict[str, Any]) -> dict[str, Any]:
         """Sync model request handler with sequential middleware processing."""
-        # Start with the base model request
-        request, messages = _prepare_model_request(state)
+        request = ModelRequest(
+            model=model,
+            tools=default_tools,
+            system_prompt=system_prompt,
+            response_format=response_format,
+            messages=state["messages"],
+            tool_choice=None,
+        )
 
         # Apply modify_model_request middleware in sequence
         for m in middleware_w_modify_model_request:
@@ -351,15 +339,26 @@ def create_agent(  # noqa: PLR0915
             filtered_state = _filter_state_for_schema(state, m.state_schema)
             request = m.modify_model_request(request, filtered_state)
 
-        # Get the bound model with the final request
+        # Get the final model and messages
         model_ = _get_bound_model(request)
+        messages = request.messages
+        if request.system_prompt:
+            messages = [SystemMessage(request.system_prompt), *messages]
+
         output = model_.invoke(messages)
         return _handle_model_output(state, output)
 
     async def amodel_request(state: dict[str, Any]) -> dict[str, Any]:
         """Async model request handler with sequential middleware processing."""
         # Start with the base model request
-        request, messages = _prepare_model_request(state)
+        request = ModelRequest(
+            model=model,
+            tools=default_tools,
+            system_prompt=system_prompt,
+            response_format=response_format,
+            messages=state["messages"],
+            tool_choice=None,
+        )
 
         # Apply modify_model_request middleware in sequence
         for m in middleware_w_modify_model_request:
@@ -367,8 +366,12 @@ def create_agent(  # noqa: PLR0915
             filtered_state = _filter_state_for_schema(state, m.state_schema)
             request = m.modify_model_request(request, filtered_state)
 
-        # Get the bound model with the final request
+        # Get the final model and messages
         model_ = _get_bound_model(request)
+        messages = request.messages
+        if request.system_prompt:
+            messages = [SystemMessage(request.system_prompt), *messages]
+
         output = await model_.ainvoke(messages)
         return _handle_model_output(state, output)
 
diff --git a/libs/langchain_v1/tests/unit_tests/agents/test_middleware_agent.py b/libs/langchain_v1/tests/unit_tests/agents/test_middleware_agent.py
index 02956f612ee..830caa81a07 100644
--- a/libs/langchain_v1/tests/unit_tests/agents/test_middleware_agent.py
+++ b/libs/langchain_v1/tests/unit_tests/agents/test_middleware_agent.py
@@ -19,7 +19,8 @@ from langchain.agents.middleware_agent import create_agent
 from langchain.agents.middleware.human_in_the_loop import HumanInTheLoopMiddleware
 from langchain.agents.middleware.prompt_caching import AnthropicPromptCachingMiddleware
 from langchain.agents.middleware.summarization import SummarizationMiddleware
-from langchain.agents.middleware.types import AgentMiddleware, ModelRequest
+from langchain.agents.middleware.types import AgentMiddleware, ModelRequest, AgentState
+from langchain_core.tools import BaseTool
 from langgraph.checkpoint.base import BaseCheckpointSaver
 from langgraph.checkpoint.memory import InMemorySaver
 
@@ -710,3 +711,25 @@ def test_summarization_middleware_full_workflow() -> None:
 
     assert summary_message is not None
     assert "Generated summary" in summary_message.content
+
+
+def test_modify_model_request() -> None:
+    class ModifyMiddleware(AgentMiddleware):
+        def modify_model_request(self, request: ModelRequest, state: AgentState) -> ModelRequest:
+            request.messages.append(HumanMessage("remember to be nice!"))
+            return request
+
+    builder = create_agent(
+        model=FakeToolCallingModel(),
+        tools=[],
+        system_prompt="You are a helpful assistant.",
+        middleware=[ModifyMiddleware()],
+    )
+
+    agent = builder.compile()
+    result = agent.invoke({"messages": [HumanMessage("Hello")]})
+    assert result["messages"][0].content == "Hello"
+    assert result["messages"][1].content == "remember to be nice!"
+    assert (
+        result["messages"][2].content == "You are a helpful assistant.-Hello-remember to be nice!"
+    )
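
With `_prepare_model_request` removed, each step now constructs a fresh `ModelRequest` from the agent defaults and threads it through the `modify_model_request` middleware chain before the model is bound, so that hook is the single place to customize a model call. A minimal sketch of another use of the hook, assuming only the `AgentMiddleware`, `AgentState`, and `ModelRequest` imports already exercised in the test above (`ForceToolChoiceMiddleware` is a hypothetical name, not part of this change):

    from langchain.agents.middleware.types import AgentMiddleware, AgentState, ModelRequest


    class ForceToolChoiceMiddleware(AgentMiddleware):
        """Hypothetical middleware: pin tool_choice on every model call."""

        def modify_model_request(self, request: ModelRequest, state: AgentState) -> ModelRequest:
            # The request arrives freshly built from the agent defaults each step,
            # so mutating it here is the only remaining way to override them.
            request.tool_choice = "any"
            return request

Passed as `create_agent(..., middleware=[ForceToolChoiceMiddleware()])`, this would override the `tool_choice=None` default visible in the inlined request construction above.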