fix(langchain): add types to human_in_the_loop middleware (#34137)

The `HumanInTheLoopMiddleware` is missing a type annotation for the
context schema. Without the fix in this PR, the following code does not
type check:

```
graph = create_agent(
    "gpt-5",
    tools=[send_email_tool, read_email_tool],
    middleware=[
        HumanInTheLoopMiddleware(
            interrupt_on={
                # Require approval or rejection for sending emails
                "send_email_tool": {
                    "allowed_decisions": ["approve", "reject"],
                },
                # Auto-approve reading emails
                "read_email_tool": False,
            }
        ),
    ],
    context_schema=ContextSchema,
)
```

```
Argument of type "list[HumanInTheLoopMiddleware]" cannot be assigned to parameter "middleware" of type "Sequence[AgentMiddleware[StateT_co@create_agent, ContextT@create_agent]]" in function "create_agent"
  "HumanInTheLoopMiddleware" is not assignable to "AgentMiddleware[AgentState[Unknown], ContextSchema | None]"
    Type parameter "ContextT@AgentMiddleware" is invariant, but "None" is not the same as "ContextSchema | None"
```
This commit is contained in:
Steffen Hausmann
2025-12-01 14:46:38 +01:00
committed by GitHub
parent 12df938ace
commit 4065106c2e

View File

@@ -7,7 +7,7 @@ from langgraph.runtime import Runtime
from langgraph.types import interrupt
from typing_extensions import NotRequired, TypedDict
from langchain.agents.middleware.types import AgentMiddleware, AgentState
from langchain.agents.middleware.types import AgentMiddleware, AgentState, ContextT, StateT
class Action(TypedDict):
@@ -102,7 +102,7 @@ class HITLResponse(TypedDict):
class _DescriptionFactory(Protocol):
"""Callable that generates a description for a tool call."""
def __call__(self, tool_call: ToolCall, state: AgentState, runtime: Runtime) -> str:
def __call__(self, tool_call: ToolCall, state: AgentState, runtime: Runtime[ContextT]) -> str:
"""Generate a description for a tool call."""
...
@@ -138,7 +138,7 @@ class InterruptOnConfig(TypedDict):
def format_tool_description(
tool_call: ToolCall,
state: AgentState,
runtime: Runtime
runtime: Runtime[ContextT]
) -> str:
import json
return (
@@ -156,7 +156,7 @@ class InterruptOnConfig(TypedDict):
"""JSON schema for the args associated with the action, if edits are allowed."""
class HumanInTheLoopMiddleware(AgentMiddleware):
class HumanInTheLoopMiddleware(AgentMiddleware[StateT, ContextT]):
"""Human in the loop middleware."""
def __init__(
@@ -204,7 +204,7 @@ class HumanInTheLoopMiddleware(AgentMiddleware):
tool_call: ToolCall,
config: InterruptOnConfig,
state: AgentState,
runtime: Runtime,
runtime: Runtime[ContextT],
) -> tuple[ActionRequest, ReviewConfig]:
"""Create an ActionRequest and ReviewConfig for a tool call."""
tool_name = tool_call["name"]
@@ -277,7 +277,7 @@ class HumanInTheLoopMiddleware(AgentMiddleware):
)
raise ValueError(msg)
def after_model(self, state: AgentState, runtime: Runtime) -> dict[str, Any] | None:
def after_model(self, state: AgentState, runtime: Runtime[ContextT]) -> dict[str, Any] | None:
"""Trigger interrupt flows for relevant tool calls after an `AIMessage`."""
messages = state["messages"]
if not messages:
@@ -350,6 +350,8 @@ class HumanInTheLoopMiddleware(AgentMiddleware):
return {"messages": [last_ai_msg, *artificial_tool_messages]}
async def aafter_model(self, state: AgentState, runtime: Runtime) -> dict[str, Any] | None:
async def aafter_model(
self, state: AgentState, runtime: Runtime[ContextT]
) -> dict[str, Any] | None:
"""Async trigger interrupt flows for relevant tool calls after an `AIMessage`."""
return self.after_model(state, runtime)