mirror of
https://github.com/hwchase17/langchain.git
synced 2025-09-13 13:36:15 +00:00
fix(langchain): use messages from model request (#32908)
Oversight when moving back to basic function call for `modify_model_request` rather than implementation as its own node. Basic test right now failing on main, passing on this branch Revealed a gap in testing. Will write up a more robust test suite for basic middleware features.
This commit is contained in:
@@ -5,7 +5,7 @@ from collections.abc import Callable, Sequence
|
||||
from typing import Any
|
||||
|
||||
from langchain_core.language_models.chat_models import BaseChatModel
|
||||
from langchain_core.messages import AIMessage, AnyMessage, SystemMessage, ToolMessage
|
||||
from langchain_core.messages import AIMessage, SystemMessage, ToolMessage
|
||||
from langchain_core.runnables import Runnable
|
||||
from langchain_core.tools import BaseTool
|
||||
from langgraph.constants import END, START
|
||||
@@ -211,24 +211,6 @@ def create_agent( # noqa: PLR0915
|
||||
context_schema=context_schema,
|
||||
)
|
||||
|
||||
def _prepare_model_request(state: dict[str, Any]) -> tuple[ModelRequest, list[AnyMessage]]:
|
||||
"""Prepare model request and messages."""
|
||||
request = state.get("model_request") or ModelRequest(
|
||||
model=model,
|
||||
tools=default_tools,
|
||||
system_prompt=system_prompt,
|
||||
response_format=response_format,
|
||||
messages=state["messages"],
|
||||
tool_choice=None,
|
||||
)
|
||||
|
||||
# prepare messages
|
||||
messages = request.messages
|
||||
if request.system_prompt:
|
||||
messages = [SystemMessage(request.system_prompt), *messages]
|
||||
|
||||
return request, messages
|
||||
|
||||
def _handle_model_output(state: dict[str, Any], output: AIMessage) -> dict[str, Any]:
|
||||
"""Handle model output including structured responses."""
|
||||
# Handle structured output with native strategy
|
||||
@@ -342,8 +324,14 @@ def create_agent( # noqa: PLR0915
|
||||
|
||||
def model_request(state: dict[str, Any]) -> dict[str, Any]:
|
||||
"""Sync model request handler with sequential middleware processing."""
|
||||
# Start with the base model request
|
||||
request, messages = _prepare_model_request(state)
|
||||
request = ModelRequest(
|
||||
model=model,
|
||||
tools=default_tools,
|
||||
system_prompt=system_prompt,
|
||||
response_format=response_format,
|
||||
messages=state["messages"],
|
||||
tool_choice=None,
|
||||
)
|
||||
|
||||
# Apply modify_model_request middleware in sequence
|
||||
for m in middleware_w_modify_model_request:
|
||||
@@ -351,15 +339,26 @@ def create_agent( # noqa: PLR0915
|
||||
filtered_state = _filter_state_for_schema(state, m.state_schema)
|
||||
request = m.modify_model_request(request, filtered_state)
|
||||
|
||||
# Get the bound model with the final request
|
||||
# Get the final model and messages
|
||||
model_ = _get_bound_model(request)
|
||||
messages = request.messages
|
||||
if request.system_prompt:
|
||||
messages = [SystemMessage(request.system_prompt), *messages]
|
||||
|
||||
output = model_.invoke(messages)
|
||||
return _handle_model_output(state, output)
|
||||
|
||||
async def amodel_request(state: dict[str, Any]) -> dict[str, Any]:
|
||||
"""Async model request handler with sequential middleware processing."""
|
||||
# Start with the base model request
|
||||
request, messages = _prepare_model_request(state)
|
||||
request = ModelRequest(
|
||||
model=model,
|
||||
tools=default_tools,
|
||||
system_prompt=system_prompt,
|
||||
response_format=response_format,
|
||||
messages=state["messages"],
|
||||
tool_choice=None,
|
||||
)
|
||||
|
||||
# Apply modify_model_request middleware in sequence
|
||||
for m in middleware_w_modify_model_request:
|
||||
@@ -367,8 +366,12 @@ def create_agent( # noqa: PLR0915
|
||||
filtered_state = _filter_state_for_schema(state, m.state_schema)
|
||||
request = m.modify_model_request(request, filtered_state)
|
||||
|
||||
# Get the bound model with the final request
|
||||
# Get the final model and messages
|
||||
model_ = _get_bound_model(request)
|
||||
messages = request.messages
|
||||
if request.system_prompt:
|
||||
messages = [SystemMessage(request.system_prompt), *messages]
|
||||
|
||||
output = await model_.ainvoke(messages)
|
||||
return _handle_model_output(state, output)
|
||||
|
||||
|
@@ -19,7 +19,8 @@ from langchain.agents.middleware_agent import create_agent
|
||||
from langchain.agents.middleware.human_in_the_loop import HumanInTheLoopMiddleware
|
||||
from langchain.agents.middleware.prompt_caching import AnthropicPromptCachingMiddleware
|
||||
from langchain.agents.middleware.summarization import SummarizationMiddleware
|
||||
from langchain.agents.middleware.types import AgentMiddleware, ModelRequest
|
||||
from langchain.agents.middleware.types import AgentMiddleware, ModelRequest, AgentState
|
||||
from langchain_core.tools import BaseTool
|
||||
|
||||
from langgraph.checkpoint.base import BaseCheckpointSaver
|
||||
from langgraph.checkpoint.memory import InMemorySaver
|
||||
@@ -710,3 +711,25 @@ def test_summarization_middleware_full_workflow() -> None:
|
||||
|
||||
assert summary_message is not None
|
||||
assert "Generated summary" in summary_message.content
|
||||
|
||||
|
||||
def test_modify_model_request() -> None:
|
||||
class ModifyMiddleware(AgentMiddleware):
|
||||
def modify_model_request(self, request: ModelRequest, state: AgentState) -> ModelRequest:
|
||||
request.messages.append(HumanMessage("remember to be nice!"))
|
||||
return request
|
||||
|
||||
builder = create_agent(
|
||||
model=FakeToolCallingModel(),
|
||||
tools=[],
|
||||
system_prompt="You are a helpful assistant.",
|
||||
middleware=[ModifyMiddleware()],
|
||||
)
|
||||
|
||||
agent = builder.compile()
|
||||
result = agent.invoke({"messages": [HumanMessage("Hello")]})
|
||||
assert result["messages"][0].content == "Hello"
|
||||
assert result["messages"][1].content == "remember to be nice!"
|
||||
assert (
|
||||
result["messages"][2].content == "You are a helpful assistant.-Hello-remember to be nice!"
|
||||
)
|
||||
|
Reference in New Issue
Block a user