fix(langchain): use state schema as input schema to middleware nodes (#33023)

We want state schema as the input schema to middleware nodes because the
conditional edges after these nodes need access to the full state.

Also, we just generally want all state passed to middleware nodes, so we
should be specifying this explicitly. If we don't, the state annotations
used by users in their node signatures are used (so they might be
missing fields).
This commit is contained in:
Sydney Runkle
2025-09-19 14:43:33 -04:00
committed by GitHub
parent 4d118777bc
commit c3654202a3

View File

@@ -226,13 +226,17 @@ def create_agent( # noqa: PLR0915
state_schemas = {m.state_schema for m in middleware}
state_schemas.add(AgentState)
state_schema = _resolve_schema(state_schemas, "StateSchema", None)
input_schema = _resolve_schema(state_schemas, "InputSchema", "input")
output_schema = _resolve_schema(state_schemas, "OutputSchema", "output")
# create graph, add nodes
graph: StateGraph[
AgentState[ResponseT], ContextT, PublicAgentState[ResponseT], PublicAgentState[ResponseT]
] = StateGraph(
state_schema=_resolve_schema(state_schemas, "StateSchema", None),
input_schema=_resolve_schema(state_schemas, "InputSchema", "input"),
output_schema=_resolve_schema(state_schemas, "OutputSchema", "output"),
state_schema=state_schema,
input_schema=input_schema,
output_schema=output_schema,
context_schema=context_schema,
)
@@ -417,16 +421,12 @@ def create_agent( # noqa: PLR0915
for m in middleware:
if m.__class__.before_model is not AgentMiddleware.before_model:
graph.add_node(
f"{m.__class__.__name__}.before_model",
m.before_model,
input_schema=m.state_schema,
f"{m.__class__.__name__}.before_model", m.before_model, input_schema=state_schema
)
if m.__class__.after_model is not AgentMiddleware.after_model:
graph.add_node(
f"{m.__class__.__name__}.after_model",
m.after_model,
input_schema=m.state_schema,
f"{m.__class__.__name__}.after_model", m.after_model, input_schema=state_schema
)
# add start edge