diff --git a/libs/langchain_v1/langchain/agents/middleware_agent.py b/libs/langchain_v1/langchain/agents/middleware_agent.py index f94ab417ba7..718e4a692e6 100644 --- a/libs/langchain_v1/langchain/agents/middleware_agent.py +++ b/libs/langchain_v1/langchain/agents/middleware_agent.py @@ -226,13 +226,17 @@ def create_agent( # noqa: PLR0915 state_schemas = {m.state_schema for m in middleware} state_schemas.add(AgentState) + state_schema = _resolve_schema(state_schemas, "StateSchema", None) + input_schema = _resolve_schema(state_schemas, "InputSchema", "input") + output_schema = _resolve_schema(state_schemas, "OutputSchema", "output") + # create graph, add nodes graph: StateGraph[ AgentState[ResponseT], ContextT, PublicAgentState[ResponseT], PublicAgentState[ResponseT] ] = StateGraph( - state_schema=_resolve_schema(state_schemas, "StateSchema", None), - input_schema=_resolve_schema(state_schemas, "InputSchema", "input"), - output_schema=_resolve_schema(state_schemas, "OutputSchema", "output"), + state_schema=state_schema, + input_schema=input_schema, + output_schema=output_schema, context_schema=context_schema, ) @@ -417,16 +421,12 @@ def create_agent( # noqa: PLR0915 for m in middleware: if m.__class__.before_model is not AgentMiddleware.before_model: graph.add_node( - f"{m.__class__.__name__}.before_model", - m.before_model, - input_schema=m.state_schema, + f"{m.__class__.__name__}.before_model", m.before_model, input_schema=state_schema ) if m.__class__.after_model is not AgentMiddleware.after_model: graph.add_node( - f"{m.__class__.__name__}.after_model", - m.after_model, - input_schema=m.state_schema, + f"{m.__class__.__name__}.after_model", m.after_model, input_schema=state_schema ) # add start edge