From c3654202a337d43782a88e99f11371181d1ebeb6 Mon Sep 17 00:00:00 2001 From: Sydney Runkle <54324534+sydney-runkle@users.noreply.github.com> Date: Fri, 19 Sep 2025 14:43:33 -0400 Subject: [PATCH] fix(langchain): use state schema as input schema to middleware nodes (#33023) We want state schema as the input schema to middleware nodes because the conditional edges after these nodes need access to the full state. Also, we just generally want all state passed to middleware nodes, so we should be specifying this explicitly. If we don't, the state annotations used by users in their node signatures are used (so they might be missing fields). --- .../langchain/agents/middleware_agent.py | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/libs/langchain_v1/langchain/agents/middleware_agent.py b/libs/langchain_v1/langchain/agents/middleware_agent.py index f94ab417ba7..718e4a692e6 100644 --- a/libs/langchain_v1/langchain/agents/middleware_agent.py +++ b/libs/langchain_v1/langchain/agents/middleware_agent.py @@ -226,13 +226,17 @@ def create_agent( # noqa: PLR0915 state_schemas = {m.state_schema for m in middleware} state_schemas.add(AgentState) + state_schema = _resolve_schema(state_schemas, "StateSchema", None) + input_schema = _resolve_schema(state_schemas, "InputSchema", "input") + output_schema = _resolve_schema(state_schemas, "OutputSchema", "output") + # create graph, add nodes graph: StateGraph[ AgentState[ResponseT], ContextT, PublicAgentState[ResponseT], PublicAgentState[ResponseT] ] = StateGraph( - state_schema=_resolve_schema(state_schemas, "StateSchema", None), - input_schema=_resolve_schema(state_schemas, "InputSchema", "input"), - output_schema=_resolve_schema(state_schemas, "OutputSchema", "output"), + state_schema=state_schema, + input_schema=input_schema, + output_schema=output_schema, context_schema=context_schema, ) @@ -417,16 +421,12 @@ def create_agent( # noqa: PLR0915 for m in middleware: if m.__class__.before_model is not AgentMiddleware.before_model: graph.add_node( - f"{m.__class__.__name__}.before_model", - m.before_model, - input_schema=m.state_schema, + f"{m.__class__.__name__}.before_model", m.before_model, input_schema=state_schema ) if m.__class__.after_model is not AgentMiddleware.after_model: graph.add_node( - f"{m.__class__.__name__}.after_model", - m.after_model, - input_schema=m.state_schema, + f"{m.__class__.__name__}.after_model", m.after_model, input_schema=state_schema ) # add start edge