diff --git a/libs/langchain_v1/langchain/agents/middleware/human_in_the_loop.py b/libs/langchain_v1/langchain/agents/middleware/human_in_the_loop.py index 43d57ab1ba5..82f5f0e4490 100644 --- a/libs/langchain_v1/langchain/agents/middleware/human_in_the_loop.py +++ b/libs/langchain_v1/langchain/agents/middleware/human_in_the_loop.py @@ -7,7 +7,7 @@ from langgraph.runtime import Runtime from langgraph.types import interrupt from typing_extensions import NotRequired, TypedDict -from langchain.agents.middleware.types import AgentMiddleware, AgentState +from langchain.agents.middleware.types import AgentMiddleware, AgentState, ContextT, StateT class Action(TypedDict): @@ -102,7 +102,7 @@ class HITLResponse(TypedDict): class _DescriptionFactory(Protocol): """Callable that generates a description for a tool call.""" - def __call__(self, tool_call: ToolCall, state: AgentState, runtime: Runtime) -> str: + def __call__(self, tool_call: ToolCall, state: AgentState, runtime: Runtime[ContextT]) -> str: """Generate a description for a tool call.""" ... @@ -138,7 +138,7 @@ class InterruptOnConfig(TypedDict): def format_tool_description( tool_call: ToolCall, state: AgentState, - runtime: Runtime + runtime: Runtime[ContextT] ) -> str: import json return ( @@ -156,7 +156,7 @@ class InterruptOnConfig(TypedDict): """JSON schema for the args associated with the action, if edits are allowed.""" -class HumanInTheLoopMiddleware(AgentMiddleware): +class HumanInTheLoopMiddleware(AgentMiddleware[StateT, ContextT]): """Human in the loop middleware.""" def __init__( @@ -204,7 +204,7 @@ class HumanInTheLoopMiddleware(AgentMiddleware): tool_call: ToolCall, config: InterruptOnConfig, state: AgentState, - runtime: Runtime, + runtime: Runtime[ContextT], ) -> tuple[ActionRequest, ReviewConfig]: """Create an ActionRequest and ReviewConfig for a tool call.""" tool_name = tool_call["name"] @@ -277,7 +277,7 @@ class HumanInTheLoopMiddleware(AgentMiddleware): ) raise ValueError(msg) - def after_model(self, state: AgentState, runtime: Runtime) -> dict[str, Any] | None: + def after_model(self, state: AgentState, runtime: Runtime[ContextT]) -> dict[str, Any] | None: """Trigger interrupt flows for relevant tool calls after an `AIMessage`.""" messages = state["messages"] if not messages: @@ -350,6 +350,8 @@ class HumanInTheLoopMiddleware(AgentMiddleware): return {"messages": [last_ai_msg, *artificial_tool_messages]} - async def aafter_model(self, state: AgentState, runtime: Runtime) -> dict[str, Any] | None: + async def aafter_model( + self, state: AgentState, runtime: Runtime[ContextT] + ) -> dict[str, Any] | None: """Async trigger interrupt flows for relevant tool calls after an `AIMessage`.""" return self.after_model(state, runtime)