new middleware

Harrison Chase
2025-09-07 12:26:43 +02:00
parent 44a60a6f09
commit dccfcd3f08
14 changed files with 741 additions and 13 deletions

View File

@@ -0,0 +1,9 @@
def _generate_correction_tool_messages(content: str, tool_calls: list):
    """Generate one correction ToolMessage per tool call so the model can retry."""
    tool_messages = []
    for tool_call in tool_calls:
        tool_messages.append({
            "role": "tool",
            "content": content,
            "tool_call_id": tool_call["id"],
        })
    return tool_messages
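
A quick illustration of the helper's contract (hypothetical usage, assuming the function above is in scope): every tool call in the batch gets the same correction text, keyed by its tool_call_id.

calls = [
    {"name": "get_weather", "args": {}, "id": "call_1"},
    {"name": "get_time", "args": {}, "id": "call_2"},
]
msgs = _generate_correction_tool_messages("Please call only ONE tool at a time.", calls)
# One {"role": "tool", ...} dict per call, each echoing the correction text.
assert [m["tool_call_id"] for m in msgs] == ["call_1", "call_2"]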

View File

@@ -0,0 +1,68 @@
from typing import Annotated, Literal

from langchain_core.messages import ToolMessage
from langchain_core.tools import tool, InjectedToolCallId
from langgraph.prebuilt import InjectedState
from langgraph.types import Command
from typing_extensions import NotRequired, TypedDict

from langchain.agents.types import AgentMiddleware, AgentState, ModelRequest


class Todo(TypedDict):
    """Todo to track."""

    content: str
    status: Literal["pending", "in_progress", "completed"]


def file_reducer(left, right):
    """Merge two file mappings, preferring entries from the right."""
    if left is None:
        return right
    if right is None:
        return left
    return {**left, **right}


class DeepAgentState(AgentState):
    todos: NotRequired[list[Todo]]
    files: Annotated[NotRequired[dict[str, str]], file_reducer]


def write_todos(
    todos: list[Todo], tool_call_id: Annotated[str, InjectedToolCallId]
) -> Command:
    """Update the todo list."""
    return Command(
        update={
            "todos": todos,
            "messages": [
                ToolMessage(f"Updated todo list to {todos}", tool_call_id=tool_call_id)
            ],
        }
    )


def ls(state: Annotated[DeepAgentState, InjectedState]) -> list[str]:
    """List all files."""
    return list(state.get("files", {}).keys())


class DeepAgentMiddleware(AgentMiddleware):
    state_schema = DeepAgentState

    def __init__(self, subagents: list | None = None):
        # Avoid a mutable default argument; fall back to an empty list.
        self.subagents = subagents or []

    @property
    def tools(self):
        return [write_todos, ls] + self.subagents

    def modify_model_request(self, request: ModelRequest, state: DeepAgentState) -> ModelRequest:
        if request.system_prompt:
            request.system_prompt += "\n\nUse the todo tool to plan as needed"
        else:
            request.system_prompt = "Use the todo tool to plan as needed"
        return request

View File

@@ -0,0 +1,22 @@
from langchain.agents.types import AgentMiddleware, ModelRequest


class DynamicPrompt(AgentMiddleware):
    def __init__(self, modifier):
        # modifier: callable taking the agent state and returning the system prompt string.
        self.modifier = modifier

    def modify_model_request(self, request: ModelRequest, state) -> ModelRequest:
        prompt = self.modifier(state)
        request.system_prompt = prompt
        return request


class DynamicMessages(AgentMiddleware):
    def __init__(self, modifier):
        # modifier: callable taking the agent state and returning the message list to send.
        self.modifier = modifier

    def modify_model_request(self, request: ModelRequest, state) -> ModelRequest:
        messages = self.modifier(state)
        request.messages = messages
        return request
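
A usage sketch (hypothetical: file names are not shown in this view, so the import path below is assumed). The modifier receives the agent state and returns the system prompt for that turn.

from langchain.agents import create_agent
from langchain.agents.middleware.dynamic import DynamicPrompt  # hypothetical path
from langchain.chat_models.fake import FakeToolCallingModel

def prompt_for(state):
    # Derive the prompt from state, e.g. how far the conversation has progressed.
    return f"You are a helpful assistant. Turn count: {len(state['messages'])}"

agent = create_agent(FakeToolCallingModel(), [], middleware=[DynamicPrompt(prompt_for)])
for s in agent.stream({"messages": [{"role": "user", "content": "hi"}]}, stream_mode="debug"):
    print(s)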

View File

@@ -7,6 +7,7 @@ from langgraph.prebuilt.interrupt import (
from langgraph.types import interrupt
from langchain.agents.types import AgentJump, AgentMiddleware, AgentState, AgentUpdate
from langchain.agents.middleware._utils import _generate_correction_tool_messages
ToolInterruptConfig = dict[str, HumanInterruptConfig]
@@ -50,11 +51,22 @@ class HumanInTheLoopMiddleware(AgentMiddleware):
# Right now, we do not support multiple tool calls with interrupts
if len(interrupt_tool_calls) > 1:
    tool_names = [t["name"] for t in interrupt_tool_calls]
    msg = (
        f"Called the following tools which require interrupts: {tool_names}\n\n"
        "You may only call ONE tool that requires an interrupt at a time"
    )
    return {
        "messages": _generate_correction_tool_messages(msg, last_message.tool_calls),
        "jump_to": "model",
    }

# Right now, we do not support interrupting a tool call if other tool calls exist
if auto_approved_tool_calls:
    tool_names = [t["name"] for t in interrupt_tool_calls]
    msg = (
        f"Called the following tools which require interrupts: {tool_names}. "
        "You also called other tools that do not require interrupts. "
        "If you call a tool that requires an interrupt, you may ONLY call that tool."
    )
    return {
        "messages": _generate_correction_tool_messages(msg, last_message.tool_calls),
        "jump_to": "model",
    }

# Only one tool call will need interrupts
tool_call = interrupt_tool_calls[0]
@@ -87,14 +99,14 @@ class HumanInTheLoopMiddleware(AgentMiddleware):
        }
        approved_tool_calls.append(new_tool_call)
    elif response["type"] == "ignore":
        return {"jump_to": "__end__"}
    elif response["type"] == "response":
        tool_message = {
            "role": "tool",
            "tool_call_id": tool_call["id"],
            "content": response["args"],
        }
        return {"messages": [tool_message], "jump_to": "model"}
    else:
        raise ValueError(f"Unknown response type: {response['type']}")

View File

@@ -0,0 +1,145 @@
from typing_extensions import NotRequired, TypedDict

from langchain.agents.middleware._utils import _generate_correction_tool_messages
from langchain.agents.types import AgentJump, AgentMiddleware, AgentState, AgentUpdate, ModelRequest
from langchain.chat_models import init_chat_model
from langchain.tools import tool

_HANDBACK_NAME = "hand_back"


class Agent(TypedDict):
    name: str
    description: str
    prompt: str
    tools: list
    model: NotRequired[str]
    model_settings: NotRequired[dict]
    response_format: NotRequired[type]


class SwarmAgentState(AgentState):
    active_agent: str | None


class SwarmMiddleware(AgentMiddleware):
    state_schema = SwarmAgentState

    def __init__(self, agents: list[Agent], starting_agent: str):
        self.agents = agents
        self.starting_agent = starting_agent
        self.agent_mapping = {a["name"]: a for a in agents}

    @property
    def tools(self):
        return [t for a in self.agents for t in a["tools"]]

    def _get_handoff_tool(self, agent: Agent):
        @tool(
            name_or_callable=f"handoff_to_{agent['name']}",
            description=f"Handoff to agent {agent['name']}. Description of this agent:\n\n{agent['description']}",
        )
        def handoff():
            pass

        return handoff

    def _get_pass_back_tool(self):
        @tool(
            name_or_callable=_HANDBACK_NAME,
            description=(
                "Call this if you are unable to handle the current request. "
                "You will hand back control of the conversation to your supervisor"
            ),
        )
        def hand_back():
            pass

        return hand_back

    def _get_main_handoff_tools(self):
        return [self._get_handoff_tool(agent) for agent in self.agents]

    def modify_model_request(self, request: ModelRequest, state: SwarmAgentState) -> ModelRequest:
        if state.get("active_agent") is None:
            request.tools = request.tools + self._get_main_handoff_tools()
            return request
        active_agent = self.agent_mapping[state["active_agent"]]
        request.system_prompt = active_agent["prompt"]
        # The active subagent gets its own tools plus the hand-back tool.
        request.tools = active_agent["tools"] + [self._get_pass_back_tool()]
        if "model" in active_agent:
            request.model = init_chat_model(active_agent["model"])
        if "model_settings" in active_agent:
            request.model_settings = active_agent["model_settings"]
        if "response_format" in active_agent:
            request.response_format = active_agent["response_format"]
        return request

    def after_model(self, state: SwarmAgentState) -> AgentUpdate | AgentJump | None:
        messages = state["messages"]
        active_agent = state.get("active_agent")
        if not messages:
            return None
        last_message = messages[-1]
        if not hasattr(last_message, "tool_calls") or not last_message.tool_calls:
            return None
        if active_agent is not None:
            handoffs = [tc for tc in last_message.tool_calls if tc["name"] == _HANDBACK_NAME]
            if len(handoffs) == 0:
                return None
            if len(handoffs) > 1:
                msg = "Multiple handoffs at the same time are not supported, please just call one at a time."
                return {
                    "messages": _generate_correction_tool_messages(msg, last_message.tool_calls),
                    "jump_to": "model",
                }
            tool_call = handoffs[0]
            return {
                "messages": [{
                    "role": "tool",
                    "tool_call_id": tool_call["id"],
                    "content": "Handed back to supervisor",
                }],
                "active_agent": None,
                "jump_to": "model",
            }
        handoff_tools = self._get_main_handoff_tools()
        handoff_tool_names = [t.name for t in handoff_tools]
        handoffs = [tc for tc in last_message.tool_calls if tc["name"] in handoff_tool_names]
        if len(handoffs) == 0:
            return None
        if len(handoffs) > 1:
            msg = "Multiple handoffs at the same time are not supported, please just call one at a time."
            return {
                "messages": _generate_correction_tool_messages(msg, last_message.tool_calls),
                "jump_to": "model",
            }
        tool_call = handoffs[0]
        handoff_to = tool_call["name"][len("handoff_to_"):]
        return {
            "messages": [{
                "role": "tool",
                "tool_call_id": tool_call["id"],
                "content": f"Handed off to agent {handoff_to}",
            }],
            "active_agent": handoff_to,
            "jump_to": "model",
        }
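
A usage sketch of the hand-back flow (hypothetical: the import path is assumed, since file names are not shown in this view). The supervisor hands off to a subagent, and the subagent calls hand_back to return control.

from langchain.agents import create_agent
from langchain.agents.middleware.swarm_handback import SwarmMiddleware  # hypothetical path
from langchain.chat_models.fake import FakeToolCallingModel

# Script the fake model: hand off to "billing", hand back, then stop.
tool_calls = [
    [{"args": {}, "id": "1", "name": "handoff_to_billing"}],
    [{"args": {}, "id": "2", "name": "hand_back"}],
    [],
]
agents = [{"name": "billing", "description": "Handles billing", "prompt": "You do billing.", "tools": []}]
middleware = SwarmMiddleware(agents=agents, starting_agent="billing")
agent = create_agent(FakeToolCallingModel(tool_calls=tool_calls), [], middleware=[middleware])
for s in agent.stream({"messages": [{"role": "user", "content": "hi"}]}, stream_mode="debug"):
    print(s)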

View File

@@ -0,0 +1,35 @@
import uuid

from langchain_core.retrievers import BaseRetriever
from langchain_core.tools import tool

from langchain.agents.types import AgentJump, AgentMiddleware, AgentState, AgentUpdate


class RAGMiddleware(AgentMiddleware):
    @classmethod
    def from_retriever(cls, retriever: BaseRetriever, description: str):
        @tool(description=description)
        def retrieve(query: str):
            return retriever.invoke(query)

        return cls(retrieve)

    def __init__(self, tool):
        self.tool = tool

    @property
    def tools(self):
        return [self.tool]

    def before_model(self, state: AgentState) -> AgentUpdate | AgentJump | None:
        # On the first user turn, force a retrieval call before the model runs.
        if len(state["messages"]) == 1:
            forced_tool_call = {
                "type": "tool_call",
                "name": self.tool.name,
                "args": {"query": state["messages"][0].content},
                "id": str(uuid.uuid4()),
            }
            return {
                "messages": [{"role": "ai", "content": "", "tool_calls": [forced_tool_call]}],
                "jump_to": "tools",
            }
        return None

View File

@@ -0,0 +1,14 @@
from langchain.agents.types import AgentJump, AgentMiddleware, AgentState, AgentUpdate


class ReflectionMiddleware(AgentMiddleware):
    def __init__(self, reflection_step):
        # reflection_step: callable taking the agent state and returning
        # feedback text, or a falsy value when no revision is needed.
        self.reflection_step = reflection_step

    def after_model(self, state: AgentState) -> AgentUpdate | AgentJump | None:
        reflection = self.reflection_step(state)
        if reflection:
            return {
                "messages": [{"role": "user", "content": reflection}],
                "jump_to": "model",
            }
        return None
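
A sketch of a reflection step (hypothetical: the import path is assumed). Any truthy return value is fed back to the model as a user message; returning None accepts the answer.

from langchain.agents import create_agent
from langchain.agents.middleware.reflection import ReflectionMiddleware  # hypothetical path
from langchain.chat_models.fake import FakeToolCallingModel

def too_short(state):
    last = state["messages"][-1]
    content = getattr(last, "content", "") or ""
    if len(content) < 10:
        return "Your answer is too short; please elaborate."
    return None

agent = create_agent(FakeToolCallingModel(), [], middleware=[ReflectionMiddleware(too_short)])
for s in agent.stream({"messages": [{"role": "user", "content": "hi"}]}, stream_mode="debug"):
    print(s)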

View File

@@ -0,0 +1,129 @@
import uuid

from typing_extensions import NotRequired, TypedDict

from langchain.agents.middleware._utils import _generate_correction_tool_messages
from langchain.agents.types import AgentJump, AgentMiddleware, AgentState, AgentUpdate, ModelRequest
from langchain.chat_models import init_chat_model
from langchain.tools import tool


class Agent(TypedDict):
    name: str
    description: str
    prompt: str
    tools: list
    model: NotRequired[str]
    model_settings: NotRequired[dict]
    response_format: NotRequired[type]


class SwarmAgentState(AgentState):
    active_agent: str | None


class SwarmMiddleware(AgentMiddleware):
    state_schema = SwarmAgentState

    def __init__(self, agents: list[Agent], starting_agent: str):
        self.agents = agents
        self.starting_agent = starting_agent
        self.agent_mapping = {a["name"]: a for a in agents}

    @property
    def tools(self):
        return [t for a in self.agents for t in a["tools"]]

    def _get_handoff_tool(self, agent: Agent):
        @tool(
            name_or_callable=f"handoff_to_{agent['name']}",
            description=f"Handoff to agent {agent['name']}. Description of this agent:\n\n{agent['description']}",
        )
        def handoff():
            pass

        return handoff

    def _get_main_handoff_tools(self):
        return [self._get_handoff_tool(agent) for agent in self.agents]

    def modify_model_request(self, request: ModelRequest, state: SwarmAgentState) -> ModelRequest:
        if state.get("active_agent") is None:
            request.tools = request.tools + self._get_main_handoff_tools()
            return request
        active_agent = self.agent_mapping[state["active_agent"]]
        request.system_prompt = active_agent["prompt"]
        request.tools = active_agent["tools"]
        if "model" in active_agent:
            request.model = init_chat_model(active_agent["model"])
        if "model_settings" in active_agent:
            request.model_settings = active_agent["model_settings"]
        if "response_format" in active_agent:
            request.response_format = active_agent["response_format"]
        return request

    def after_model(self, state: SwarmAgentState) -> AgentUpdate | AgentJump | None:
        messages = state["messages"]
        active_agent = state.get("active_agent")
        if not messages:
            return None
        last_message = messages[-1]
        if not hasattr(last_message, "tool_calls") or not last_message.tool_calls:
            if active_agent is None:
                return None
            # The active subagent produced a final answer: synthesize a
            # hand-back tool call/response pair so the supervisor resumes.
            fake_tool_call_id = str(uuid.uuid4())
            last_message.tool_calls = [{
                "type": "tool_call",
                "name": "hand_back_to_supervisor",
                "args": {},
                "id": fake_tool_call_id,
            }]
            fake_tool_message = {
                "role": "tool",
                "content": "Handed back to supervisor",
                "tool_call_id": fake_tool_call_id,
            }
            return {
                "messages": [last_message, fake_tool_message],
                # Reset the active agent so control actually returns to the supervisor.
                "active_agent": None,
                "jump_to": "model",
            }
        if active_agent is not None:
            return None
        handoff_tools = self._get_main_handoff_tools()
        handoff_tool_names = [t.name for t in handoff_tools]
        handoffs = [tc for tc in last_message.tool_calls if tc["name"] in handoff_tool_names]
        if len(handoffs) == 0:
            return None
        if len(handoffs) > 1:
            msg = "Multiple handoffs at the same time are not supported, please just call one at a time."
            return {
                "messages": _generate_correction_tool_messages(msg, last_message.tool_calls),
                "jump_to": "model",
            }
        tool_call = handoffs[0]
        handoff_to = tool_call["name"][len("handoff_to_"):]
        return {
            "messages": [{
                "role": "tool",
                "tool_call_id": tool_call["id"],
                "content": f"Handed off to agent {handoff_to}",
            }],
            "active_agent": handoff_to,
            "jump_to": "model",
        }

View File

@@ -0,0 +1,111 @@
from typing_extensions import NotRequired, TypedDict

from langchain.agents.middleware._utils import _generate_correction_tool_messages
from langchain.agents.types import AgentJump, AgentMiddleware, AgentState, AgentUpdate, ModelRequest
from langchain.chat_models import init_chat_model
from langchain.tools import tool


class Agent(TypedDict):
    name: str
    description: str
    prompt: str
    tools: list
    model: NotRequired[str]
    model_settings: NotRequired[dict]
    response_format: NotRequired[type]


class SwarmAgentState(AgentState):
    active_agent: NotRequired[str]


class SwarmMiddleware(AgentMiddleware):
    state_schema = SwarmAgentState

    def __init__(self, agents: list[Agent], starting_agent: str):
        self.agents = agents
        self.starting_agent = starting_agent
        self.agent_mapping = {a["name"]: a for a in agents}

    @property
    def tools(self):
        return [t for a in self.agents for t in a["tools"]]

    def _get_handoff_tool(self, agent: Agent):
        @tool(
            name_or_callable=f"handoff_to_{agent['name']}",
            description=f"Handoff to agent {agent['name']}. Description of this agent:\n\n{agent['description']}",
        )
        def handoff():
            pass

        return handoff

    def _get_handoff_tools(self, active_agent: str):
        # Every peer except the currently active agent is a handoff target.
        return [self._get_handoff_tool(agent) for agent in self.agents if agent["name"] != active_agent]

    def modify_model_request(self, request: ModelRequest, state: SwarmAgentState) -> ModelRequest:
        active_agent_name = state.get("active_agent") or self.starting_agent
        active_agent = self.agent_mapping[active_agent_name]
        request.system_prompt = active_agent["prompt"]
        request.tools = active_agent["tools"] + self._get_handoff_tools(active_agent_name)
        if "model" in active_agent:
            request.model = init_chat_model(active_agent["model"])
        if "model_settings" in active_agent:
            request.model_settings = active_agent["model_settings"]
        if "response_format" in active_agent:
            request.response_format = active_agent["response_format"]
        return request

    def after_model(self, state: SwarmAgentState) -> AgentUpdate | AgentJump | None:
        active_agent = state.get("active_agent") or self.starting_agent
        messages = state["messages"]
        if not messages:
            return None
        last_message = messages[-1]
        if not hasattr(last_message, "tool_calls") or not last_message.tool_calls:
            return None
        handoff_tools = self._get_handoff_tools(active_agent)
        handoff_tool_names = [t.name for t in handoff_tools]
        handoffs = [tc for tc in last_message.tool_calls if tc["name"] in handoff_tool_names]
        if len(handoffs) == 0:
            return None
        if len(handoffs) > 1:
            msg = "Multiple handoffs at the same time are not supported, please just call one at a time."
            return {
                "messages": _generate_correction_tool_messages(msg, last_message.tool_calls),
                "jump_to": "model",
            }
        tool_call = handoffs[0]
        handoff_to = tool_call["name"][len("handoff_to_"):]
        return {
            "messages": [{
                "role": "tool",
                "tool_call_id": tool_call["id"],
                "content": f"Handed off to agent {handoff_to}",
            }],
            "active_agent": handoff_to,
            "jump_to": "model",
        }
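
Because model, model_settings, and response_format are optional per agent, one peer can override the default model while the others inherit it. A minimal sketch (the model string is an example):

middleware = SwarmMiddleware(
    agents=[
        {"name": "researcher", "description": "Finds facts", "prompt": "Research.", "tools": []},
        # This peer overrides the model used for its turns.
        {"name": "writer", "description": "Writes prose", "prompt": "Write.", "tools": [],
         "model": "openai:gpt-4o-mini"},
    ],
    starting_agent="researcher",
)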

View File

@@ -130,7 +130,7 @@ def create_agent(
    native_output_binding = ProviderStrategyBinding.from_schema_spec(
        response_format.schema_spec
    )

middleware_tools = [t for m in middleware for t in m.tools]

# Setup tools
if isinstance(tools, list):
    # Extract builtin provider tools (dict format)
@@ -139,25 +139,28 @@ def create_agent(
    # Add structured output tools to regular tools
    structured_tools = [info.tool for info in structured_output_tools.values()]
    all_tools = middleware_tools + regular_tools + structured_tools

    if all_tools:  # Only create ToolNode if we have tools
        tool_node = ToolNode(tools=all_tools)
    else:
        tool_node = None

    default_tools = regular_tools + builtin_tools + structured_tools + middleware_tools
else:
    # tools is ToolNode or None
    tool_node = tools
    if tool_node:
        default_tools = list(tool_node.tools_by_name.values()) + middleware_tools
        # Update tool node to know about tools provided by middleware
        all_tools = list(tool_node.tools_by_name.values()) + middleware_tools
        tool_node = ToolNode(all_tools)
        # Add structured output tools
        for info in structured_output_tools.values():
            default_tools.append(info.tool)
    else:
        default_tools = (
            list(structured_output_tools.values()) if structured_output_tools else []
        ) + middleware_tools

# validate middleware
assert len({m.__class__.__name__ for m in middleware}) == len(middleware), (
@@ -400,8 +403,21 @@ def create_agent(
        [first_node, "tools", END],
    )
else:
    if last_node == "model_request":
        # If no tools, just go to END from model
        graph.add_edge(last_node, END)
    else:
        # If after_model, then need to check for jump_to
        _add_middleware_edge(
            graph,
            middleware_w_after[0].after_model,
            f"{middleware_w_after[0].__class__.__name__}.after_model",
            END,
            first_node,
            tools_available=tool_node is not None,
        )

# Add middleware edges (same as before)
if middleware_w_before:
@@ -414,11 +430,16 @@ def create_agent(
        first_node,
        tools_available=tool_node is not None,
    )

    if middleware_w_modify_model_request:
        first_modify = middleware_w_modify_model_request[0]
        next_node = f"{first_modify.__class__.__name__}.modify_model_request"
    else:
        next_node = "model_request"

    _add_middleware_edge(
        graph,
        middleware_w_before[-1].before_model,
        f"{middleware_w_before[-1].__class__.__name__}.before_model",
        next_node,
        first_node,
        tools_available=tool_node is not None,
    )
@@ -530,7 +551,9 @@ def _add_middleware_edge(
def jump_edge(state: AgentState) -> str:
    return _resolve_jump(state.get("jump_to"), model_destination) or default_destination

destinations = [default_destination]
if END != default_destination:
    destinations.append(END)
if tools_available:
    destinations.append("tools")
if name != model_destination:

View File

@@ -0,0 +1,112 @@
import json
from collections.abc import Callable, Sequence
from dataclasses import asdict, is_dataclass
from typing import Any, Generic, Literal, TypeVar, Union

from langchain_core.callbacks import CallbackManagerForLLMRun
from langchain_core.language_models import BaseChatModel, LanguageModelInput
from langchain_core.messages import AIMessage, BaseMessage, ToolCall
from langchain_core.outputs import ChatGeneration, ChatResult
from langchain_core.runnables import Runnable
from langchain_core.tools import BaseTool
from pydantic import BaseModel

StructuredResponseT = TypeVar("StructuredResponseT")


class FakeToolCallingModel(BaseChatModel, Generic[StructuredResponseT]):
    tool_calls: Union[list[list[ToolCall]], list[list[dict]], None] = None
    structured_response: StructuredResponseT | None = None
    index: int = 0
    tool_style: Literal["openai", "anthropic"] = "openai"
    tools: list = []

    def _generate(
        self,
        messages: list[BaseMessage],
        stop: list[str] | None = None,
        run_manager: CallbackManagerForLLMRun | None = None,
        **kwargs: Any,
    ) -> ChatResult:
        """Top-level call."""
        rf = kwargs.get("response_format")
        is_native = isinstance(rf, dict) and rf.get("type") == "json_schema"

        if self.tool_calls:
            if is_native:
                tool_calls = (
                    self.tool_calls[self.index] if self.index < len(self.tool_calls) else []
                )
            else:
                tool_calls = self.tool_calls[self.index % len(self.tool_calls)]
        else:
            tool_calls = []

        if is_native and not tool_calls:
            if isinstance(self.structured_response, BaseModel):
                content_obj = self.structured_response.model_dump()
            elif is_dataclass(self.structured_response):
                content_obj = asdict(self.structured_response)
            elif isinstance(self.structured_response, dict):
                content_obj = self.structured_response
            message = AIMessage(content=json.dumps(content_obj), id=str(self.index))
        else:
            messages_string = "-".join([m.content for m in messages]) + str(kwargs) + str(self.tools)
            message = AIMessage(
                content=messages_string,
                id=str(self.index),
                tool_calls=tool_calls.copy(),
            )
        self.index += 1
        return ChatResult(generations=[ChatGeneration(message=message)])

    @property
    def _llm_type(self) -> str:
        return "fake-tool-call-model"

    def bind_tools(
        self,
        tools: Sequence[Union[dict[str, Any], type[BaseModel], Callable, BaseTool]],
        **kwargs: Any,
    ) -> Runnable[LanguageModelInput, BaseMessage]:
        self.tools = tools
        if len(tools) == 0:
            msg = "Must provide at least one tool"
            raise ValueError(msg)

        tool_dicts = []
        for tool in tools:
            if isinstance(tool, dict):
                tool_dicts.append(tool)
                continue
            if not isinstance(tool, BaseTool):
                continue

            # NOTE: this is a simplified tool spec for testing purposes only
            if self.tool_style == "openai":
                tool_dicts.append(
                    {
                        "type": "function",
                        "function": {
                            "name": tool.name,
                        },
                    }
                )
            elif self.tool_style == "anthropic":
                tool_dicts.append(
                    {
                        "name": tool.name,
                    }
                )
        return self.bind(tools=tool_dicts)
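
The scripts below drive this fake model by scripting tool_calls per turn; a minimal standalone illustration (assuming the class is importable as in those scripts):

from langchain_core.messages import HumanMessage

model = FakeToolCallingModel(tool_calls=[[{"name": "ls", "args": {}, "id": "1"}], []])
first = model.invoke([HumanMessage("hi")])   # turn 0: emits the scripted "ls" tool call
second = model.invoke([HumanMessage("hi")])  # turn 1: no tool calls, plain AIMessage
print(first.tool_calls, second.tool_calls)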

View File

@@ -0,0 +1,9 @@
from langchain.agents import create_agent
from langchain.agents.middleware.deepagents import DeepAgentMiddleware
from langchain.chat_models.fake import FakeToolCallingModel

model = FakeToolCallingModel()
agent = create_agent(model, [], middleware=[DeepAgentMiddleware()])

for s in agent.stream({"messages": [{"role": "user", "content": "hi"}]}, stream_mode="debug"):
    print(s)

View File

@@ -0,0 +1,24 @@
from langchain_core.callbacks import CallbackManagerForRetrieverRun
from langchain_core.documents import Document
from langchain_core.retrievers import BaseRetriever

from langchain.agents import create_agent
from langchain.agents.middleware.rag import RAGMiddleware
from langchain.chat_models.fake import FakeToolCallingModel


class FakeRetriever(BaseRetriever):
    def _get_relevant_documents(
        self, query: str, *, run_manager: CallbackManagerForRetrieverRun
    ) -> list[Document]:
        return [Document(page_content="foo")]


model = FakeToolCallingModel()
middleware = RAGMiddleware.from_retriever(FakeRetriever(), "foo")
agent = create_agent(model, [], middleware=[middleware])

print(agent.get_graph())
for s in agent.stream({"messages": [{"role": "user", "content": "hi"}]}, stream_mode="debug"):
    print(s)

View File

@@ -0,0 +1,15 @@
from langchain.agents import create_agent
from langchain.agents.middleware.swarm import SwarmMiddleware
from langchain.chat_models.fake import FakeToolCallingModel

# Script the fake model: the first turn hands off to foo2, the second turn ends.
tool_calls = [[{"args": {}, "id": "1", "name": "handoff_to_foo2"}], []]
model = FakeToolCallingModel(tool_calls=tool_calls)

subagents = [
    {"name": "foo1", "description": "bar1", "prompt": "hi", "tools": []},
    {"name": "foo2", "description": "bar1", "prompt": "bye", "tools": []},
]
middleware = SwarmMiddleware(agents=subagents, starting_agent="foo1")
agent = create_agent(model, [], middleware=[middleware])

print(agent.get_graph())
for s in agent.stream({"messages": [{"role": "user", "content": "hi"}]}, stream_mode="debug"):
    print(s)