mirror of
https://github.com/hwchase17/langchain.git
synced 2025-09-13 21:47:12 +00:00
fix(langchain): use messages from model request (#32908)
Oversight when moving back to basic function call for `modify_model_request` rather than implementation as its own node. Basic test right now failing on main, passing on this branch Revealed a gap in testing. Will write up a more robust test suite for basic middleware features.
This commit is contained in:
@@ -5,7 +5,7 @@ from collections.abc import Callable, Sequence
|
|||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
from langchain_core.language_models.chat_models import BaseChatModel
|
from langchain_core.language_models.chat_models import BaseChatModel
|
||||||
from langchain_core.messages import AIMessage, AnyMessage, SystemMessage, ToolMessage
|
from langchain_core.messages import AIMessage, SystemMessage, ToolMessage
|
||||||
from langchain_core.runnables import Runnable
|
from langchain_core.runnables import Runnable
|
||||||
from langchain_core.tools import BaseTool
|
from langchain_core.tools import BaseTool
|
||||||
from langgraph.constants import END, START
|
from langgraph.constants import END, START
|
||||||
@@ -211,24 +211,6 @@ def create_agent( # noqa: PLR0915
|
|||||||
context_schema=context_schema,
|
context_schema=context_schema,
|
||||||
)
|
)
|
||||||
|
|
||||||
def _prepare_model_request(state: dict[str, Any]) -> tuple[ModelRequest, list[AnyMessage]]:
|
|
||||||
"""Prepare model request and messages."""
|
|
||||||
request = state.get("model_request") or ModelRequest(
|
|
||||||
model=model,
|
|
||||||
tools=default_tools,
|
|
||||||
system_prompt=system_prompt,
|
|
||||||
response_format=response_format,
|
|
||||||
messages=state["messages"],
|
|
||||||
tool_choice=None,
|
|
||||||
)
|
|
||||||
|
|
||||||
# prepare messages
|
|
||||||
messages = request.messages
|
|
||||||
if request.system_prompt:
|
|
||||||
messages = [SystemMessage(request.system_prompt), *messages]
|
|
||||||
|
|
||||||
return request, messages
|
|
||||||
|
|
||||||
def _handle_model_output(state: dict[str, Any], output: AIMessage) -> dict[str, Any]:
|
def _handle_model_output(state: dict[str, Any], output: AIMessage) -> dict[str, Any]:
|
||||||
"""Handle model output including structured responses."""
|
"""Handle model output including structured responses."""
|
||||||
# Handle structured output with native strategy
|
# Handle structured output with native strategy
|
||||||
@@ -342,8 +324,14 @@ def create_agent( # noqa: PLR0915
|
|||||||
|
|
||||||
def model_request(state: dict[str, Any]) -> dict[str, Any]:
|
def model_request(state: dict[str, Any]) -> dict[str, Any]:
|
||||||
"""Sync model request handler with sequential middleware processing."""
|
"""Sync model request handler with sequential middleware processing."""
|
||||||
# Start with the base model request
|
request = ModelRequest(
|
||||||
request, messages = _prepare_model_request(state)
|
model=model,
|
||||||
|
tools=default_tools,
|
||||||
|
system_prompt=system_prompt,
|
||||||
|
response_format=response_format,
|
||||||
|
messages=state["messages"],
|
||||||
|
tool_choice=None,
|
||||||
|
)
|
||||||
|
|
||||||
# Apply modify_model_request middleware in sequence
|
# Apply modify_model_request middleware in sequence
|
||||||
for m in middleware_w_modify_model_request:
|
for m in middleware_w_modify_model_request:
|
||||||
@@ -351,15 +339,26 @@ def create_agent( # noqa: PLR0915
|
|||||||
filtered_state = _filter_state_for_schema(state, m.state_schema)
|
filtered_state = _filter_state_for_schema(state, m.state_schema)
|
||||||
request = m.modify_model_request(request, filtered_state)
|
request = m.modify_model_request(request, filtered_state)
|
||||||
|
|
||||||
# Get the bound model with the final request
|
# Get the final model and messages
|
||||||
model_ = _get_bound_model(request)
|
model_ = _get_bound_model(request)
|
||||||
|
messages = request.messages
|
||||||
|
if request.system_prompt:
|
||||||
|
messages = [SystemMessage(request.system_prompt), *messages]
|
||||||
|
|
||||||
output = model_.invoke(messages)
|
output = model_.invoke(messages)
|
||||||
return _handle_model_output(state, output)
|
return _handle_model_output(state, output)
|
||||||
|
|
||||||
async def amodel_request(state: dict[str, Any]) -> dict[str, Any]:
|
async def amodel_request(state: dict[str, Any]) -> dict[str, Any]:
|
||||||
"""Async model request handler with sequential middleware processing."""
|
"""Async model request handler with sequential middleware processing."""
|
||||||
# Start with the base model request
|
# Start with the base model request
|
||||||
request, messages = _prepare_model_request(state)
|
request = ModelRequest(
|
||||||
|
model=model,
|
||||||
|
tools=default_tools,
|
||||||
|
system_prompt=system_prompt,
|
||||||
|
response_format=response_format,
|
||||||
|
messages=state["messages"],
|
||||||
|
tool_choice=None,
|
||||||
|
)
|
||||||
|
|
||||||
# Apply modify_model_request middleware in sequence
|
# Apply modify_model_request middleware in sequence
|
||||||
for m in middleware_w_modify_model_request:
|
for m in middleware_w_modify_model_request:
|
||||||
@@ -367,8 +366,12 @@ def create_agent( # noqa: PLR0915
|
|||||||
filtered_state = _filter_state_for_schema(state, m.state_schema)
|
filtered_state = _filter_state_for_schema(state, m.state_schema)
|
||||||
request = m.modify_model_request(request, filtered_state)
|
request = m.modify_model_request(request, filtered_state)
|
||||||
|
|
||||||
# Get the bound model with the final request
|
# Get the final model and messages
|
||||||
model_ = _get_bound_model(request)
|
model_ = _get_bound_model(request)
|
||||||
|
messages = request.messages
|
||||||
|
if request.system_prompt:
|
||||||
|
messages = [SystemMessage(request.system_prompt), *messages]
|
||||||
|
|
||||||
output = await model_.ainvoke(messages)
|
output = await model_.ainvoke(messages)
|
||||||
return _handle_model_output(state, output)
|
return _handle_model_output(state, output)
|
||||||
|
|
||||||
|
@@ -19,7 +19,8 @@ from langchain.agents.middleware_agent import create_agent
|
|||||||
from langchain.agents.middleware.human_in_the_loop import HumanInTheLoopMiddleware
|
from langchain.agents.middleware.human_in_the_loop import HumanInTheLoopMiddleware
|
||||||
from langchain.agents.middleware.prompt_caching import AnthropicPromptCachingMiddleware
|
from langchain.agents.middleware.prompt_caching import AnthropicPromptCachingMiddleware
|
||||||
from langchain.agents.middleware.summarization import SummarizationMiddleware
|
from langchain.agents.middleware.summarization import SummarizationMiddleware
|
||||||
from langchain.agents.middleware.types import AgentMiddleware, ModelRequest
|
from langchain.agents.middleware.types import AgentMiddleware, ModelRequest, AgentState
|
||||||
|
from langchain_core.tools import BaseTool
|
||||||
|
|
||||||
from langgraph.checkpoint.base import BaseCheckpointSaver
|
from langgraph.checkpoint.base import BaseCheckpointSaver
|
||||||
from langgraph.checkpoint.memory import InMemorySaver
|
from langgraph.checkpoint.memory import InMemorySaver
|
||||||
@@ -710,3 +711,25 @@ def test_summarization_middleware_full_workflow() -> None:
|
|||||||
|
|
||||||
assert summary_message is not None
|
assert summary_message is not None
|
||||||
assert "Generated summary" in summary_message.content
|
assert "Generated summary" in summary_message.content
|
||||||
|
|
||||||
|
|
||||||
|
def test_modify_model_request() -> None:
|
||||||
|
class ModifyMiddleware(AgentMiddleware):
|
||||||
|
def modify_model_request(self, request: ModelRequest, state: AgentState) -> ModelRequest:
|
||||||
|
request.messages.append(HumanMessage("remember to be nice!"))
|
||||||
|
return request
|
||||||
|
|
||||||
|
builder = create_agent(
|
||||||
|
model=FakeToolCallingModel(),
|
||||||
|
tools=[],
|
||||||
|
system_prompt="You are a helpful assistant.",
|
||||||
|
middleware=[ModifyMiddleware()],
|
||||||
|
)
|
||||||
|
|
||||||
|
agent = builder.compile()
|
||||||
|
result = agent.invoke({"messages": [HumanMessage("Hello")]})
|
||||||
|
assert result["messages"][0].content == "Hello"
|
||||||
|
assert result["messages"][1].content == "remember to be nice!"
|
||||||
|
assert (
|
||||||
|
result["messages"][2].content == "You are a helpful assistant.-Hello-remember to be nice!"
|
||||||
|
)
|
||||||
|
Reference in New Issue
Block a user