fix(langchain): use messages from model request (#32908)

This fixes an oversight from moving `modify_model_request` back to a basic function call rather than implementing it as its own node.

The basic test added here fails on main and passes on this branch.

This revealed a gap in testing; a more robust test suite for the basic middleware features will follow.
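
For context: the old `_prepare_model_request` helper flattened the system prompt into the message list before any `modify_model_request` middleware ran, so middleware edits to `request.messages` never reached the model. A minimal sketch of the failure mode, using simplified stand-in types (`Request`, `prepare`) rather than the real langchain classes:

from dataclasses import dataclass, field


@dataclass
class Request:
    """Simplified stand-in for ModelRequest (hypothetical, for illustration)."""

    system_prompt: str
    messages: list[str] = field(default_factory=list)


def prepare(state_messages: list[str]) -> tuple[Request, list[str]]:
    request = Request(system_prompt="You are a helpful assistant.", messages=state_messages)
    # Bug: messages are materialized here, before middleware sees the request.
    return request, [request.system_prompt, *request.messages]


request, stale_messages = prepare(["Hello"])
request.messages.append("remember to be nice!")  # what a middleware edit would do

# The buggy flow invoked the model with the stale, pre-middleware messages:
assert stale_messages == ["You are a helpful assistant.", "Hello"]

# The fix rebuilds messages from the final request after all middleware has run:
final_messages = [request.system_prompt, *request.messages]
assert final_messages == ["You are a helpful assistant.", "Hello", "remember to be nice!"]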
Sydney Runkle
2025-09-12 08:18:02 -04:00
committed by GitHub
parent 649d8a8223
commit 9e78ff19ab
2 changed files with 51 additions and 25 deletions


@@ -5,7 +5,7 @@ from collections.abc import Callable, Sequence
 from typing import Any
 
 from langchain_core.language_models.chat_models import BaseChatModel
-from langchain_core.messages import AIMessage, AnyMessage, SystemMessage, ToolMessage
+from langchain_core.messages import AIMessage, SystemMessage, ToolMessage
 from langchain_core.runnables import Runnable
 from langchain_core.tools import BaseTool
 from langgraph.constants import END, START
@@ -211,24 +211,6 @@ def create_agent( # noqa: PLR0915
         context_schema=context_schema,
     )
 
-    def _prepare_model_request(state: dict[str, Any]) -> tuple[ModelRequest, list[AnyMessage]]:
-        """Prepare model request and messages."""
-        request = state.get("model_request") or ModelRequest(
-            model=model,
-            tools=default_tools,
-            system_prompt=system_prompt,
-            response_format=response_format,
-            messages=state["messages"],
-            tool_choice=None,
-        )
-
-        # prepare messages
-        messages = request.messages
-        if request.system_prompt:
-            messages = [SystemMessage(request.system_prompt), *messages]
-
-        return request, messages
-
     def _handle_model_output(state: dict[str, Any], output: AIMessage) -> dict[str, Any]:
         """Handle model output including structured responses."""
         # Handle structured output with native strategy
@@ -342,8 +324,14 @@ def create_agent( # noqa: PLR0915
     def model_request(state: dict[str, Any]) -> dict[str, Any]:
         """Sync model request handler with sequential middleware processing."""
         # Start with the base model request
-        request, messages = _prepare_model_request(state)
+        request = ModelRequest(
+            model=model,
+            tools=default_tools,
+            system_prompt=system_prompt,
+            response_format=response_format,
+            messages=state["messages"],
+            tool_choice=None,
+        )
 
         # Apply modify_model_request middleware in sequence
         for m in middleware_w_modify_model_request:
@@ -351,15 +339,26 @@ def create_agent( # noqa: PLR0915
             filtered_state = _filter_state_for_schema(state, m.state_schema)
             request = m.modify_model_request(request, filtered_state)
 
-        # Get the bound model with the final request
+        # Get the final model and messages
         model_ = _get_bound_model(request)
+        messages = request.messages
+        if request.system_prompt:
+            messages = [SystemMessage(request.system_prompt), *messages]
+
         output = model_.invoke(messages)
         return _handle_model_output(state, output)
 
     async def amodel_request(state: dict[str, Any]) -> dict[str, Any]:
         """Async model request handler with sequential middleware processing."""
         # Start with the base model request
-        request, messages = _prepare_model_request(state)
+        request = ModelRequest(
+            model=model,
+            tools=default_tools,
+            system_prompt=system_prompt,
+            response_format=response_format,
+            messages=state["messages"],
+            tool_choice=None,
+        )
 
         # Apply modify_model_request middleware in sequence
         for m in middleware_w_modify_model_request:
@@ -367,8 +366,12 @@ def create_agent( # noqa: PLR0915
             filtered_state = _filter_state_for_schema(state, m.state_schema)
             request = m.modify_model_request(request, filtered_state)
 
-        # Get the bound model with the final request
+        # Get the final model and messages
         model_ = _get_bound_model(request)
+        messages = request.messages
+        if request.system_prompt:
+            messages = [SystemMessage(request.system_prompt), *messages]
+
         output = await model_.ainvoke(messages)
         return _handle_model_output(state, output)


@@ -19,7 +19,8 @@ from langchain.agents.middleware_agent import create_agent
 from langchain.agents.middleware.human_in_the_loop import HumanInTheLoopMiddleware
 from langchain.agents.middleware.prompt_caching import AnthropicPromptCachingMiddleware
 from langchain.agents.middleware.summarization import SummarizationMiddleware
-from langchain.agents.middleware.types import AgentMiddleware, ModelRequest
+from langchain.agents.middleware.types import AgentMiddleware, ModelRequest, AgentState
 from langchain_core.tools import BaseTool
 from langgraph.checkpoint.base import BaseCheckpointSaver
 from langgraph.checkpoint.memory import InMemorySaver
@@ -710,3 +711,25 @@ def test_summarization_middleware_full_workflow() -> None:
     assert summary_message is not None
     assert "Generated summary" in summary_message.content
+
+
+def test_modify_model_request() -> None:
+    class ModifyMiddleware(AgentMiddleware):
+        def modify_model_request(self, request: ModelRequest, state: AgentState) -> ModelRequest:
+            request.messages.append(HumanMessage("remember to be nice!"))
+            return request
+
+    builder = create_agent(
+        model=FakeToolCallingModel(),
+        tools=[],
+        system_prompt="You are a helpful assistant.",
+        middleware=[ModifyMiddleware()],
+    )
+    agent = builder.compile()
+    result = agent.invoke({"messages": [HumanMessage("Hello")]})
+
+    assert result["messages"][0].content == "Hello"
+    assert result["messages"][1].content == "remember to be nice!"
+    assert (
+        result["messages"][2].content == "You are a helpful assistant.-Hello-remember to be nice!"
+    )
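
Note: the final assertion leans on the behavior of FakeToolCallingModel, which in this test suite appears to echo the hyphen-joined contents of its input (system prompt, then each message) as the AI reply; that is what makes the middleware-injected message observable in the model output.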