mirror of
https://github.com/hwchase17/langchain.git
synced 2026-03-18 11:07:36 +00:00
fix(langchain): add types to human_in_the_loop middleware (#34137)
The `HumanInTheLoopMiddleware` is missing a type annotation for the
context schema. Without the fix in this PR, the following code does not
type check:
```
graph = create_agent(
"gpt-5",
tools=[send_email_tool, read_email_tool],
middleware=[
HumanInTheLoopMiddleware(
interrupt_on={
# Require approval or rejection for sending emails
"send_email_tool": {
"allowed_decisions": ["approve", "reject"],
},
# Auto-approve reading emails
"read_email_tool": False,
}
),
],
context_schema=ContextSchema,
)
```
```
Argument of type "list[HumanInTheLoopMiddleware]" cannot be assigned to parameter "middleware" of type "Sequence[AgentMiddleware[StateT_co@create_agent, ContextT@create_agent]]" in function "create_agent"
"HumanInTheLoopMiddleware" is not assignable to "AgentMiddleware[AgentState[Unknown], ContextSchema | None]"
Type parameter "ContextT@AgentMiddleware" is invariant, but "None" is not the same as "ContextSchema | None"
```
This commit is contained in:
@@ -7,7 +7,7 @@ from langgraph.runtime import Runtime
|
||||
from langgraph.types import interrupt
|
||||
from typing_extensions import NotRequired, TypedDict
|
||||
|
||||
from langchain.agents.middleware.types import AgentMiddleware, AgentState
|
||||
from langchain.agents.middleware.types import AgentMiddleware, AgentState, ContextT, StateT
|
||||
|
||||
|
||||
class Action(TypedDict):
|
||||
@@ -102,7 +102,7 @@ class HITLResponse(TypedDict):
|
||||
class _DescriptionFactory(Protocol):
|
||||
"""Callable that generates a description for a tool call."""
|
||||
|
||||
def __call__(self, tool_call: ToolCall, state: AgentState, runtime: Runtime) -> str:
|
||||
def __call__(self, tool_call: ToolCall, state: AgentState, runtime: Runtime[ContextT]) -> str:
|
||||
"""Generate a description for a tool call."""
|
||||
...
|
||||
|
||||
@@ -138,7 +138,7 @@ class InterruptOnConfig(TypedDict):
|
||||
def format_tool_description(
|
||||
tool_call: ToolCall,
|
||||
state: AgentState,
|
||||
runtime: Runtime
|
||||
runtime: Runtime[ContextT]
|
||||
) -> str:
|
||||
import json
|
||||
return (
|
||||
@@ -156,7 +156,7 @@ class InterruptOnConfig(TypedDict):
|
||||
"""JSON schema for the args associated with the action, if edits are allowed."""
|
||||
|
||||
|
||||
class HumanInTheLoopMiddleware(AgentMiddleware):
|
||||
class HumanInTheLoopMiddleware(AgentMiddleware[StateT, ContextT]):
|
||||
"""Human in the loop middleware."""
|
||||
|
||||
def __init__(
|
||||
@@ -204,7 +204,7 @@ class HumanInTheLoopMiddleware(AgentMiddleware):
|
||||
tool_call: ToolCall,
|
||||
config: InterruptOnConfig,
|
||||
state: AgentState,
|
||||
runtime: Runtime,
|
||||
runtime: Runtime[ContextT],
|
||||
) -> tuple[ActionRequest, ReviewConfig]:
|
||||
"""Create an ActionRequest and ReviewConfig for a tool call."""
|
||||
tool_name = tool_call["name"]
|
||||
@@ -277,7 +277,7 @@ class HumanInTheLoopMiddleware(AgentMiddleware):
|
||||
)
|
||||
raise ValueError(msg)
|
||||
|
||||
def after_model(self, state: AgentState, runtime: Runtime) -> dict[str, Any] | None:
|
||||
def after_model(self, state: AgentState, runtime: Runtime[ContextT]) -> dict[str, Any] | None:
|
||||
"""Trigger interrupt flows for relevant tool calls after an `AIMessage`."""
|
||||
messages = state["messages"]
|
||||
if not messages:
|
||||
@@ -350,6 +350,8 @@ class HumanInTheLoopMiddleware(AgentMiddleware):
|
||||
|
||||
return {"messages": [last_ai_msg, *artificial_tool_messages]}
|
||||
|
||||
async def aafter_model(self, state: AgentState, runtime: Runtime) -> dict[str, Any] | None:
|
||||
async def aafter_model(
|
||||
self, state: AgentState, runtime: Runtime[ContextT]
|
||||
) -> dict[str, Any] | None:
|
||||
"""Async trigger interrupt flows for relevant tool calls after an `AIMessage`."""
|
||||
return self.after_model(state, runtime)
|
||||
|
||||
Reference in New Issue
Block a user