Compare commits

...

12 Commits

Author          SHA1        Message                                                     Date
Harrison Chase  10fa55e287  cr                                                          2025-09-18 10:42:43 -04:00
Harrison Chase  8289f09fcd  Merge branch 'sr/new-agent' into harrison/new-middleware   2025-09-07 12:29:46 +02:00
Harrison Chase  a65fb5297e  cr                                                          2025-09-07 12:28:38 +02:00
Harrison Chase  dccfcd3f08  new middleware                                              2025-09-07 12:26:43 +02:00
Harrison Chase  44a60a6f09  cr                                                          2025-09-06 13:11:02 -07:00
Sydney Runkle   26bef498e8  refactored summarization node                               2025-09-05 16:53:19 -04:00
Sydney Runkle   4ead6355ba  enhancing prompt caching                                    2025-09-05 14:37:06 -04:00
Sydney Runkle   1564a286dd  model settings and prompt caching                           2025-09-05 13:38:54 -04:00
Sydney Runkle   15230b47c6  adding support to existing agent                            2025-09-05 11:16:56 -04:00
Sydney Runkle   a6a4b0d58f  some linting                                                2025-09-05 10:16:21 -04:00
Sydney Runkle   522f99da34  initial middlewares + testing                               2025-09-05 10:15:25 -04:00
Sydney Runkle   82d8a415e3  first pass at new agent w/ middleware                       2025-09-05 10:02:51 -04:00
25 changed files with 2114 additions and 0 deletions

View File

@@ -0,0 +1,9 @@
def _generate_correction_tool_messages(content: str, tool_calls: list):
    """Build one tool message per tool call, each carrying the same correction text."""
    tool_messages = []
    for tool_call in tool_calls:
        tool_messages.append({
            "role": "tool",
            "content": content,
            "tool_call_id": tool_call["id"],
        })
    return tool_messages
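
As a quick illustration (not part of this diff), the helper above fans a single correction string out into one tool message per pending tool call; the tool-call dict below is hypothetical:

calls = [{"name": "get_weather", "args": {"city": "SF"}, "id": "call_1"}]
msgs = _generate_correction_tool_messages("Please call only one tool at a time.", calls)
# msgs == [{"role": "tool", "content": "Please call only one tool at a time.", "tool_call_id": "call_1"}]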

View File

@@ -0,0 +1,68 @@
from langchain.agents.types import AgentMiddleware, AgentState, ModelRequest
from typing import Annotated, Literal, NotRequired
from typing_extensions import TypedDict
class Todo(TypedDict):
"""Todo to track."""
content: str
status: Literal["pending", "in_progress", "completed"]
def file_reducer(l, r):
if l is None:
return r
elif r is None:
return l
else:
return {**l, **r}
class DeepAgentState(AgentState):
todos: NotRequired[list[Todo]]
files: Annotated[NotRequired[dict[str, str]], file_reducer]
from langchain_core.tools import tool, InjectedToolCallId
from langgraph.types import Command
from langchain_core.messages import ToolMessage
from typing import Annotated, Union
from langgraph.prebuilt import InjectedState
def write_todos(
todos: list[Todo], tool_call_id: Annotated[str, InjectedToolCallId]
) -> Command:
"""write todos"""
return Command(
update={
"todos": todos,
"messages": [
ToolMessage(f"Updated todo list to {todos}", tool_call_id=tool_call_id)
],
}
)
def ls(state: Annotated[DeepAgentState, InjectedState]) -> list[str]:
"""List all files"""
return list(state.get("files", {}).keys())
class DeepAgentMiddleware(AgentMiddleware):
state_schema = DeepAgentState
    def __init__(self, subagents: list | None = None):
        self.subagents = subagents or []
@property
def tools(self):
return [write_todos, ls] + self.subagents
def modify_model_request(self, request: ModelRequest, state: DeepAgentState) -> ModelRequest:
if request.system_prompt:
request.system_prompt += "\n\nUse the todo tool to plan as needed"
else:
request.system_prompt = "Use the todo tool to plan as needed"
return request

View File

@@ -0,0 +1,22 @@
from langchain.agents.types import AgentMiddleware, AgentState, ModelRequest
class DynamicPrompt(AgentMiddleware):
def __init__(self, modifier):
self.modifier = modifier
def modify_model_request(self, request: ModelRequest, state) -> ModelRequest:
prompt = self.modifier(state)
request.system_prompt = prompt
return request
class DynamicMessages(AgentMiddleware):
def __init__(self, modifier):
self.modifier = modifier
def modify_model_request(self, request: ModelRequest, state) -> ModelRequest:
messages = self.modifier(state)
request.messages = messages
return request
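
A minimal usage sketch for DynamicPrompt (not part of this diff; the entrypoint, provider string, and modifier function are assumptions for illustration):

from langchain.agents.middleware_agent import create_agent

def prompt_for_state(state) -> str:
    # Hypothetical modifier: derive the system prompt from the conversation length.
    return f"You are a helpful assistant. {len(state['messages'])} messages so far."

agent = create_agent(
    model="openai:gpt-4o-mini",  # assumed provider:model string
    tools=[],
    middleware=[DynamicPrompt(prompt_for_state)],
).compile()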

View File

@@ -0,0 +1,53 @@
from langchain.agents.types import AgentMiddleware, AgentState, ModelRequest, AgentUpdate, AgentJump
from typing_extensions import TypedDict
PROMPT = """Check if the conversation trips any of the guardrails. If it trips multiple, flag the guardrail that is violated the most
<conversation>
{conversation}
</conversation>
<guardrails>
{guardrails}
</guardrails>"""
class Guardrail(TypedDict):
name: str
prompt: str
response_str: str
class InputGuardrailMiddleware(AgentMiddleware):
def __init__(self, guardrails: list[Guardrail], model):
super().__init__()
self.guardrails = guardrails
self.model = model
def _convert_to_string(self, state: AgentState):
# TODO: improve
return str(state['messages'])
def before_model(self, state: AgentState) -> AgentUpdate | AgentJump | None:
conversation = self._convert_to_string(state)
guardrails = "\n".join([
f"<{guard['name']}>{guard['prompt']}</{guard['name']}>" for guard in self.guardrails
])
prompt = PROMPT.format(conversation=conversation, guardrails=guardrails)
        class Response(TypedDict):
            """`flagged` should be the name of the violated guardrail, or `none` if no guardrail is violated."""

            flagged: str
response = self.model.with_structured_output(Response).invoke(prompt)
if response['flagged'] == 'none':
return
else:
            resp = {g["name"]: g["response_str"] for g in self.guardrails}
            return {
                "messages": [{"role": "ai", "content": resp[response["flagged"]]}],
"jump_to": "__end__"
}
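
A usage sketch for the guardrail middleware (illustrative; the guardrail content and model string are assumptions):

from langchain.chat_models import init_chat_model

guardrails = [
    {
        "name": "off_topic",
        "prompt": "The user asks about anything unrelated to billing.",
        "response_str": "Sorry, I can only help with billing questions.",
    }
]
middleware = InputGuardrailMiddleware(guardrails, init_chat_model("openai:gpt-4o-mini"))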

View File

@@ -0,0 +1,115 @@
from langgraph.prebuilt.interrupt import (
ActionRequest,
HumanInterrupt,
HumanInterruptConfig,
HumanResponse,
)
from langgraph.types import interrupt
from langchain.agents.types import AgentJump, AgentMiddleware, AgentState, AgentUpdate
from langchain.agents.middleware._utils import _generate_correction_tool_messages
ToolInterruptConfig = dict[str, HumanInterruptConfig]
class HumanInTheLoopMiddleware(AgentMiddleware):
def __init__(
self,
tool_configs: ToolInterruptConfig,
message_prefix: str = "Tool execution requires approval",
):
super().__init__()
self.tool_configs = tool_configs
self.message_prefix = message_prefix
def after_model(self, state: AgentState) -> AgentUpdate | AgentJump | None:
messages = state["messages"]
if not messages:
return None
last_message = messages[-1]
if not hasattr(last_message, "tool_calls") or not last_message.tool_calls:
return None
# Separate tool calls that need interrupts from those that don't
interrupt_tool_calls = []
auto_approved_tool_calls = []
for tool_call in last_message.tool_calls:
tool_name = tool_call["name"]
if tool_name in self.tool_configs:
interrupt_tool_calls.append(tool_call)
else:
auto_approved_tool_calls.append(tool_call)
# If no interrupts needed, return early
if not interrupt_tool_calls:
return None
approved_tool_calls = auto_approved_tool_calls.copy()
# Right now, we do not support multiple tool calls with interrupts
if len(interrupt_tool_calls) > 1:
tool_names = [t['name'] for t in interrupt_tool_calls]
msg = f"Called the following tools which require interrupts: {tool_names}\n\nYou may only call ONE tool that requires an interrupt at a time"
return {
"messages": _generate_correction_tool_messages(msg, last_message.tool_calls),
"jump_to": "model"
}
# Right now, we do not support interrupting a tool call if other tool calls exist
if auto_approved_tool_calls:
tool_names = [t['name'] for t in interrupt_tool_calls]
msg = f"Called the following tools which require interrupts: {tool_names}. You also called other tools that do not require interrupts. If you call a tool that requires and interrupt, you may ONLY call that tool."
return {
"messages": _generate_correction_tool_messages(msg,
last_message.tool_calls),
"jump_to": "model"
}
# Only one tool call will need interrupts
tool_call = interrupt_tool_calls[0]
tool_name = tool_call["name"]
tool_args = tool_call["args"]
description = f"{self.message_prefix}\n\nTool: {tool_name}\nArgs: {tool_args}"
tool_config = self.tool_configs[tool_name]
request: HumanInterrupt = {
"action_request": ActionRequest(
action=tool_name,
args=tool_args,
),
"config": tool_config,
"description": description,
}
responses: list[HumanResponse] = interrupt([request])
response = responses[0]
if response["type"] == "accept":
approved_tool_calls.append(tool_call)
elif response["type"] == "edit":
edited: ActionRequest = response["args"]
new_tool_call = {
"type": "tool_call",
"name": tool_call["name"],
"args": edited["args"],
"id": tool_call["id"],
}
approved_tool_calls.append(new_tool_call)
elif response["type"] == "ignore":
return {"jump_to": "__end__"}
elif response["type"] == "response":
tool_message = {
"role": "tool",
"tool_call_id": tool_call["id"],
"content": response["args"],
}
return {"messages": [tool_message], "jump_to": "model"}
else:
raise ValueError(f"Unknown response type: {response['type']}")
last_message.tool_calls = approved_tool_calls
return {"messages": [last_message]}

View File

@@ -0,0 +1,145 @@
from langchain.agents.types import AgentMiddleware, AgentState, ModelRequest, AgentUpdate, AgentJump
from typing_extensions import TypedDict, Type
from langchain.tools import tool
from langchain.chat_models import init_chat_model
from langchain.agents.middleware._utils import _generate_correction_tool_messages
_HANDBACK_NAME = "hand_back"
class Agent(TypedDict):
name: str
description: str
prompt: str
tools: list
model: str
model_settings: dict
response_format: Type
class SwarmAgentState(AgentState):
active_agent: str | None
class SwarmMiddleware(AgentMiddleware):
state_schema = SwarmAgentState
def __init__(self, agents: list[Agent], starting_agent: str):
self.agents = agents
self.starting_agent = starting_agent
self.agent_mapping = {a['name']: a for a in agents}
@property
def tools(self):
return [t for a in self.agents for t in a['tools']]
def _get_handoff_tool(self, agent: Agent):
@tool(
name_or_callable=f"handoff_to_{agent['name']}",
description=f"Handoff to agent {agent['name']}. Description of this agent:\n\n{agent['description']}"
)
def handoff():
pass
return handoff
def _get_pass_back_tool(self):
@tool(name_or_callable=_HANDBACK_NAME,
description="Call this if you are unable to handle the current request. You will hand back control of the conversation to your supervisor")
def hand_back():
pass
return hand_back
def _get_main_handoff_tools(self):
tools = []
for agent in self.agents:
tools.append(self._get_handoff_tool(agent))
return tools
def modify_model_request(self, request: ModelRequest, state: SwarmAgentState) -> ModelRequest:
if state.get('active_agent') is None:
request.tools = request.tools + self._get_main_handoff_tools()
return request
active_agent = self.agent_mapping[state['active_agent']]
request.system_prompt = active_agent['prompt']
        request.tools = active_agent['tools'] + [self._get_pass_back_tool()]
if 'model' in active_agent:
request.model = init_chat_model(active_agent['model'])
if 'model_settings' in active_agent:
request.model_settings = active_agent['model_settings']
if 'response_format' in active_agent:
request.response_format = active_agent['response_format']
return request
def after_model(self, state: SwarmAgentState) -> AgentUpdate | AgentJump | None:
messages = state["messages"]
active_agent = state.get('active_agent')
if not messages:
return None
last_message = messages[-1]
if not hasattr(last_message, "tool_calls") or not last_message.tool_calls:
return None
if active_agent is not None:
handoffs = []
for tool_call in last_message.tool_calls:
if tool_call['name'] == _HANDBACK_NAME:
handoffs.append(tool_call)
if len(handoffs) == 0:
return None
elif len(handoffs) > 1:
msg = "Multiple handoffs at the same time are not supported, please just call one at a time."
return {
"messages": _generate_correction_tool_messages(msg,
last_message.tool_calls),
"jump_to": "model"
}
else:
tool_call = handoffs[0]
return {
"messages": [{
"role": "tool",
"tool_call_id": tool_call["id"],
"content": f"Handed back to supervisor",
}],
"active_agent": None,
"jump_to": "model"
}
handoff_tools = self._get_main_handoff_tools()
handoff_tool_names = [t.name for t in handoff_tools]
handoffs = []
for tool_call in last_message.tool_calls:
if tool_call['name'] in handoff_tool_names:
handoffs.append(tool_call)
if len(handoffs) == 0:
return
elif len(handoffs) > 1:
msg = "Multiple handoffs at the same time are not supported, please just call one at a time."
return {
"messages": _generate_correction_tool_messages(msg,
last_message.tool_calls),
"jump_to": "model"
}
else:
tool_call = handoffs[0]
handoff_to = tool_call['name'][len("handoff_to_"):]
return {
"messages":[{
"role": "tool",
"tool_call_id": tool_call["id"],
"content": f"Handed off to agent {handoff_to}",
}],
"active_agent":handoff_to,
"jump_to": "model"
}

View File

@@ -0,0 +1,24 @@
import operator
from typing import Annotated
from langchain.agents.types import AgentJump, AgentMiddleware, AgentState, AgentUpdate
class State(AgentState):
model_request_count: Annotated[int, operator.add]
class ModelRequestLimitMiddleware(AgentMiddleware):
"""Terminates after N model requests"""
state_schema = State
def __init__(self, max_requests: int = 10):
self.max_requests = max_requests
def before_model(self, state: State) -> AgentUpdate | AgentJump | None:
# TODO: want to be able to configure end behavior here
if state.get("model_request_count", 0) == self.max_requests:
return {"jump_to": "__end__"}
return {"model_request_count": 1}

View File

@@ -0,0 +1,53 @@
from typing import Literal
from langchain.agents.types import AgentMiddleware, AgentState, ModelRequest
class AnthropicPromptCachingMiddleware(AgentMiddleware):
"""Prompt Caching Middleware - Optimizes API usage by caching conversation prefixes for Anthropic models.
    Learn more about Anthropic prompt caching [here](https://docs.anthropic.com/en/docs/build-with-claude/prompt-caching).
"""
def __init__(
self,
type: Literal["ephemeral"] = "ephemeral",
ttl: Literal["5m", "1h"] = "5m",
min_messages_to_cache: int = 0,
):
"""Initialize the middleware with cache control settings.
Args:
type: The type of cache to use, only "ephemeral" is supported.
ttl: The time to live for the cache, only "5m" and "1h" are supported.
min_messages_to_cache: The minimum number of messages until the cache is used, default is 0.
"""
self.type = type
self.ttl = ttl
self.min_messages_to_cache = min_messages_to_cache
def modify_model_request(self, request: ModelRequest, state: AgentState) -> ModelRequest:
"""Modify the model request to add cache control blocks."""
try:
from langchain_anthropic import ChatAnthropic
except ImportError:
            raise ValueError(
                "AnthropicPromptCachingMiddleware only supports Anthropic models. "
                "Please install langchain-anthropic."
            )
if not isinstance(request.model, ChatAnthropic):
            raise ValueError(
                "AnthropicPromptCachingMiddleware only supports Anthropic models, "
                f"not instances of {type(request.model)}"
            )
messages_count = (
len(request.messages) + 1 if request.system_prompt else len(request.messages)
)
if messages_count < self.min_messages_to_cache:
return request
request.model_settings["cache_control"] = {"type": self.type, "ttl": self.ttl}
return request

View File

@@ -0,0 +1,35 @@
from langchain.agents.types import AgentMiddleware, AgentState, ModelRequest, AgentJump, AgentUpdate
from langchain_core.retrievers import BaseRetriever
from langchain_core.tools import tool
import uuid
class RAGMiddleware(AgentMiddleware):
@classmethod
def from_retriever(cls, retriever: BaseRetriever, description: str):
@tool(description=description)
def retrieve(query: str):
            return retriever.invoke(query)
return cls(retrieve)
def __init__(self, tool):
self.tool = tool
@property
def tools(self):
return [self.tool]
def before_model(self, state: AgentState) -> AgentUpdate | AgentJump | None:
if len(state['messages']) == 1:
forced_tool_call = {
"type": "tool_call",
"name": self.tool.name,
"args": {"query": state['messages'][0].content},
"id": str(uuid.uuid4()),
}
return {
"messages": [{"role": "ai", "content": None, "tool_calls": [forced_tool_call]}],
"jump_to": "tools"
}

View File

@@ -0,0 +1,14 @@
from langchain.agents.types import AgentMiddleware, AgentState, ModelRequest, AgentUpdate, AgentJump
class ReflectionMiddleware(AgentMiddleware):
def __init__(self, reflection_step):
self.reflection_step = reflection_step
def after_model(self, state: AgentState) -> AgentUpdate | AgentJump | None:
reflection = self.reflection_step(state)
if reflection:
return {
"messages": [{'role': 'user', 'content': reflection}],
"jump_to": "model"
}
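
A sketch of a reflection_step callable (hypothetical criterion); returning a non-empty string sends the critique back to the model, while returning None lets the run proceed:

def check_for_citations(state) -> str | None:
    last = state["messages"][-1]
    content = getattr(last, "content", "") or ""
    if content and "[source:" not in content:
        return "Your last answer has no citation. Please cite a source as [source: ...]."
    return None

middleware = ReflectionMiddleware(check_for_citations)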

View File

@@ -0,0 +1,240 @@
import uuid
from collections.abc import Callable, Iterable
from typing import cast
from langchain_core.messages import (
AIMessage,
AnyMessage,
MessageLikeRepresentation,
ToolMessage,
)
from langchain_core.messages.utils import count_tokens_approximately, trim_messages
from langgraph.graph.message import REMOVE_ALL_MESSAGES, RemoveMessage
from langchain.agents.types import AgentMiddleware, AgentState
from langchain.chat_models import BaseChatModel
TokenCounter = Callable[[Iterable[MessageLikeRepresentation]], int]
DEFAULT_SUMMARY_PROMPT = """<role>
Context Extraction Assistant
</role>
<primary_objective>
Your sole objective in this task is to extract the highest quality/most relevant context from the conversation history below.
</primary_objective>
<objective_information>
You're nearing the total number of input tokens you can accept, so you must extract the highest quality/most relevant pieces of information from your conversation history.
This context will then overwrite the conversation history presented below. Because of this, ensure the context you extract is only the most important information to your overall goal.
</objective_information>
<instructions>
The conversation history below will be replaced with the context you extract in this step. Because of this, you must do your very best to extract and record all of the most important context from the conversation history.
You want to ensure that you don't repeat any actions you've already completed, so the context you extract from the conversation history should be focused on the most important information to your overall goal.
</instructions>
The user will message you with the full message history you'll be extracting context from, to then replace. Carefully read over it all, and think deeply about what information is most important to your overall goal that should be saved:
With all of this in mind, please carefully read over the entire conversation history, and extract the most important and relevant context to replace it so that you can free up space in the conversation history.
Respond ONLY with the extracted context. Do not include any additional information, or text before or after the extracted context.
<messages>
Messages to summarize:
{messages}
</messages>"""
SUMMARY_PREFIX = "## Previous conversation summary:"
_DEFAULT_MESSAGES_TO_KEEP = 20
_DEFAULT_TRIM_TOKEN_LIMIT = 4000
_DEFAULT_FALLBACK_MESSAGE_COUNT = 15
_SEARCH_RANGE_FOR_TOOL_PAIRS = 5
class SummarizationMiddleware(AgentMiddleware):
"""Middleware that summarizes conversation history when token limits are approached.
This middleware monitors message token counts and automatically summarizes older
messages when a threshold is reached, preserving recent messages and maintaining
context continuity by ensuring AI/Tool message pairs remain together.
"""
def __init__(
self,
model: BaseChatModel,
max_tokens_before_summary: int | None = None,
messages_to_keep: int = _DEFAULT_MESSAGES_TO_KEEP,
token_counter: TokenCounter = count_tokens_approximately,
summary_prompt: str = DEFAULT_SUMMARY_PROMPT,
summary_prefix: str = SUMMARY_PREFIX,
):
"""Initialize the summarization middleware.
Args:
model: The language model to use for generating summaries.
max_tokens_before_summary: Token threshold to trigger summarization.
If None, summarization is disabled.
messages_to_keep: Number of recent messages to preserve after summarization.
token_counter: Function to count tokens in messages.
summary_prompt: Prompt template for generating summaries.
summary_prefix: Prefix added to system message when including summary.
"""
super().__init__()
self.model = model
self.max_tokens_before_summary = max_tokens_before_summary
self.messages_to_keep = messages_to_keep
self.token_counter = token_counter
self.summary_prompt = summary_prompt
self.summary_prefix = summary_prefix
def before_model(self, state: AgentState) -> AgentState | None:
"""Process messages before model invocation, potentially triggering summarization."""
messages = state["messages"]
self._ensure_message_ids(messages)
total_tokens = self.token_counter(messages)
        if (
            self.max_tokens_before_summary is None
            or total_tokens < self.max_tokens_before_summary
        ):
            # Summarization is disabled (no threshold) or the threshold has not been reached.
            return None
cutoff_index = self._find_safe_cutoff(messages)
if cutoff_index <= 0:
return None
messages_to_summarize, preserved_messages = self._partition_messages(
messages, cutoff_index
)
summary = self._create_summary(messages_to_summarize)
new_messages = self._build_new_messages(summary)
return {
"messages": [
RemoveMessage(id=REMOVE_ALL_MESSAGES),
*new_messages,
*preserved_messages,
]
}
def _build_new_messages(self, summary: str):
return [
{"role": "user", "content": f"Here is a summary of the conversation to date:\n\n{summary}"}
]
def _ensure_message_ids(self, messages: list[AnyMessage]) -> None:
"""Ensure all messages have unique IDs for the add_messages reducer."""
for msg in messages:
if msg.id is None:
msg.id = str(uuid.uuid4())
def _partition_messages(
self,
conversation_messages: list[AnyMessage],
cutoff_index: int,
) -> tuple[list[AnyMessage], list[AnyMessage]]:
"""Partition messages into those to summarize and those to preserve."""
messages_to_summarize = conversation_messages[:cutoff_index]
preserved_messages = conversation_messages[cutoff_index:]
return messages_to_summarize, preserved_messages
def _find_safe_cutoff(self, messages: list[AnyMessage]) -> int:
"""Find safe cutoff point that preserves AI/Tool message pairs.
Returns the index where messages can be safely cut without separating
related AI and Tool messages. Returns 0 if no safe cutoff is found.
"""
if len(messages) <= self.messages_to_keep:
return 0
target_cutoff = len(messages) - self.messages_to_keep
for i in range(target_cutoff, -1, -1):
if self._is_safe_cutoff_point(messages, i):
return i
return 0
def _is_safe_cutoff_point(self, messages: list[AnyMessage], cutoff_index: int) -> bool:
"""Check if cutting at index would separate AI/Tool message pairs."""
if cutoff_index >= len(messages):
return True
search_start = max(0, cutoff_index - _SEARCH_RANGE_FOR_TOOL_PAIRS)
search_end = min(len(messages), cutoff_index + _SEARCH_RANGE_FOR_TOOL_PAIRS)
for i in range(search_start, search_end):
if not self._has_tool_calls(messages[i]):
continue
tool_call_ids = self._extract_tool_call_ids(messages[i])
if self._cutoff_separates_tool_pair(messages, i, cutoff_index, tool_call_ids):
return False
return True
def _has_tool_calls(self, message: AnyMessage) -> bool:
"""Check if message is an AI message with tool calls."""
return (
isinstance(message, AIMessage) and hasattr(message, "tool_calls") and message.tool_calls
)
def _extract_tool_call_ids(self, ai_message: AIMessage) -> set[str]:
"""Extract tool call IDs from an AI message."""
tool_call_ids = set()
for tc in ai_message.tool_calls:
call_id = tc.get("id") if isinstance(tc, dict) else getattr(tc, "id", None)
if call_id is not None:
tool_call_ids.add(call_id)
return tool_call_ids
def _cutoff_separates_tool_pair(
self,
messages: list[AnyMessage],
ai_message_index: int,
cutoff_index: int,
tool_call_ids: set[str],
) -> bool:
"""Check if cutoff separates an AI message from its corresponding tool messages."""
for j in range(ai_message_index + 1, len(messages)):
message = messages[j]
if isinstance(message, ToolMessage) and message.tool_call_id in tool_call_ids:
ai_before_cutoff = ai_message_index < cutoff_index
tool_before_cutoff = j < cutoff_index
if ai_before_cutoff != tool_before_cutoff:
return True
return False
def _create_summary(self, messages_to_summarize: list[AnyMessage]) -> str:
"""Generate summary for the given messages."""
if not messages_to_summarize:
return "No previous conversation history."
trimmed_messages = self._trim_messages_for_summary(messages_to_summarize)
if not trimmed_messages:
return "Previous conversation was too long to summarize."
try:
response = self.model.invoke(self.summary_prompt.format(messages=trimmed_messages))
return cast("str", response.content).strip()
except Exception as e:
return f"Error generating summary: {e!s}"
def _trim_messages_for_summary(self, messages: list[AnyMessage]) -> list[AnyMessage]:
"""Trim messages to fit within summary generation limits."""
try:
return trim_messages(
messages,
max_tokens=_DEFAULT_TRIM_TOKEN_LIMIT,
token_counter=self.token_counter,
start_on="human",
strategy="last",
allow_partial=True,
include_system=True,
)
except Exception:
return messages[-_DEFAULT_FALLBACK_MESSAGE_COUNT:]

View File

@@ -0,0 +1,129 @@
import uuid
from langchain.agents.types import AgentMiddleware, AgentState, ModelRequest, AgentUpdate, AgentJump
from typing_extensions import TypedDict, Type
from langchain.tools import tool
from langchain.chat_models import init_chat_model
from langchain.agents.middleware._utils import _generate_correction_tool_messages
class Agent(TypedDict):
name: str
description: str
prompt: str
tools: list
model: str
model_settings: dict
response_format: Type
class SwarmAgentState(AgentState):
active_agent: str | None
class SwarmMiddleware(AgentMiddleware):
state_schema = SwarmAgentState
def __init__(self, agents: list[Agent], starting_agent: str):
self.agents = agents
self.starting_agent = starting_agent
self.agent_mapping = {a['name']: a for a in agents}
@property
def tools(self):
return [t for a in self.agents for t in a['tools']]
def _get_handoff_tool(self, agent: Agent):
@tool(
name_or_callable=f"handoff_to_{agent['name']}",
description=f"Handoff to agent {agent['name']}. Description of this agent:\n\n{agent['description']}"
)
def handoff():
pass
return handoff
def _get_main_handoff_tools(self):
tools = []
for agent in self.agents:
tools.append(self._get_handoff_tool(agent))
return tools
def modify_model_request(self, request: ModelRequest, state: SwarmAgentState) -> ModelRequest:
if state.get('active_agent') is None:
request.tools = request.tools + self._get_main_handoff_tools()
return request
active_agent = self.agent_mapping[state['active_agent']]
request.system_prompt = active_agent['prompt']
request.tools = active_agent['tools']
if 'model' in active_agent:
request.model = init_chat_model(active_agent['model'])
if 'model_settings' in active_agent:
request.model_settings = active_agent['model_settings']
if 'response_format' in active_agent:
request.response_format = active_agent['response_format']
return request
def after_model(self, state: SwarmAgentState) -> AgentUpdate | AgentJump | None:
messages = state["messages"]
active_agent = state.get('active_agent')
if not messages:
return None
last_message = messages[-1]
if not hasattr(last_message, "tool_calls") or not last_message.tool_calls:
if active_agent is None:
return None
else:
fake_tool_call_id = str(uuid.uuid4())
last_message.tool_calls = [{
"type": "tool_call",
"name": "hand_back_to_supervisor",
"args": {},
"id": fake_tool_call_id,
}]
fake_tool_message = {"role": "tool", "content": "Handed back to supervisor", "tool_call_id": fake_tool_call_id}
return {
"messages": [last_message, fake_tool_message],
"jump_to": "model"
}
if active_agent is not None:
return None
handoff_tools = self._get_main_handoff_tools()
handoff_tool_names = [t.name for t in handoff_tools]
handoffs = []
for tool_call in last_message.tool_calls:
if tool_call['name'] in handoff_tool_names:
handoffs.append(tool_call)
if len(handoffs) == 0:
return
elif len(handoffs) > 1:
msg = "Multiple handoffs at the same time are not supported, please just call one at a time."
return {
"messages": _generate_correction_tool_messages(msg,
last_message.tool_calls),
"jump_to": "model"
}
else:
tool_call = handoffs[0]
handoff_to = tool_call['name'][len("handoff_to_"):]
return {
"messages":[{
"role": "tool",
"tool_call_id": tool_call["id"],
"content": f"Handed off to agent {handoff_to}",
}],
"active_agent":handoff_to,
"jump_to": "model"
}

View File

@@ -0,0 +1,111 @@
from langchain.agents.types import AgentMiddleware, AgentState, ModelRequest, AgentUpdate, AgentJump
from typing_extensions import TypedDict, Type
from langchain.tools import tool
from langchain.chat_models import init_chat_model
from langchain.agents.middleware._utils import _generate_correction_tool_messages
class Agent(TypedDict):
name: str
description: str
prompt: str
tools: list
model: str
model_settings: dict
response_format: Type
class SwarmAgentState(AgentState):
active_agent: str
class SwarmMiddleware(AgentMiddleware):
state_schema = SwarmAgentState
def __init__(self, agents: list[Agent], starting_agent: str):
self.agents = agents
self.starting_agent = starting_agent
self.agent_mapping = {a['name']: a for a in agents}
@property
def tools(self):
return [t for a in self.agents for t in a['tools']]
def _get_handoff_tool(self, agent: Agent):
@tool(
name_or_callable=f"handoff_to_{agent['name']}",
description=f"Handoff to agent {agent['name']}. Description of this agent:\n\n{agent['description']}"
)
def handoff():
pass
return handoff
def _get_handoff_tools(self, active_agent: str):
tools = []
for agent in self.agents:
if agent['name'] != active_agent:
tools.append(self._get_handoff_tool(agent))
return tools
def modify_model_request(self, request: ModelRequest, state: SwarmAgentState) -> ModelRequest:
active_agent_name = state.get('active_agent') or self.starting_agent
active_agent = self.agent_mapping[active_agent_name]
request.system_prompt = active_agent['prompt']
        request.tools = active_agent['tools'] + self._get_handoff_tools(active_agent_name)
if 'model' in active_agent:
request.model = init_chat_model(active_agent['model'])
if 'model_settings' in active_agent:
request.model_settings = active_agent['model_settings']
if 'response_format' in active_agent:
request.response_format = active_agent['response_format']
return request
def after_model(self, state: SwarmAgentState) -> AgentUpdate | AgentJump | None:
active_agent = state.get('active_agent') or self.starting_agent
messages = state["messages"]
if not messages:
return None
last_message = messages[-1]
if not hasattr(last_message, "tool_calls") or not last_message.tool_calls:
return None
handoff_tools = self._get_handoff_tools(active_agent)
handoff_tool_names = [t.name for t in handoff_tools]
handoffs = []
for tool_call in last_message.tool_calls:
if tool_call['name'] in handoff_tool_names:
handoffs.append(tool_call)
if len(handoffs) == 0:
return
elif len(handoffs) > 1:
msg = "Multiple handoffs at the same time are not supported, please just call one at a time."
return {
"messages": _generate_correction_tool_messages(msg,
last_message.tool_calls),
"jump_to": "model"
}
else:
tool_call = handoffs[0]
handoff_to = tool_call['name'][len("handoff_to_"):]
return {
"messages":[{
"role": "tool",
"tool_call_id": tool_call["id"],
"content": f"Handed off to agent {handoff_to}",
}],
"active_agent":handoff_to,
"jump_to": "model"
}

View File

@@ -0,0 +1,36 @@
from typing import cast
from langchain_core.messages import AIMessage
from langchain.agents.types import AgentJump, AgentMiddleware, AgentState, AgentUpdate
class State(AgentState):
    tool_call_count: dict[str, int]
class ToolCallLimitMiddleware(AgentMiddleware):
"""Terminates after a specific tool is called N times"""
state_schema = State
def __init__(self, tool_limits: dict[str, int]):
self.tool_limits = tool_limits
def after_model(self, state: State) -> AgentUpdate | AgentJump | None:
ai_msg: AIMessage = cast("AIMessage", state["messages"][-1])
tool_calls = {}
for call in ai_msg.tool_calls or []:
tool_calls[call["name"]] = tool_calls.get(call["name"], 0) + 1
aggregate_calls = state["tool_call_count"].copy()
for tool_name in tool_calls:
aggregate_calls[tool_name] = aggregate_calls.get(tool_name, 0) + 1
for tool_name, max_calls in self.tool_limits.items():
count = aggregate_calls.get(tool_name, 0)
if count == max_calls:
return {"tool_call_count": aggregate_calls, "jump_to": "__end__"}
return {"tool_call_count": aggregate_calls}

View File

@@ -0,0 +1,564 @@
from collections.abc import Callable, Sequence
from inspect import signature
from typing import Any, Union, cast
from langchain_core.language_models.chat_models import BaseChatModel
from langchain_core.messages import AIMessage, BaseMessage, SystemMessage, ToolMessage
from langchain_core.tools import BaseTool
from langgraph.constants import END, START
from langgraph.graph.state import StateGraph
from typing_extensions import TypeVar
# Import structured output classes from the old implementation
from langchain.agents.structured_output import (
MultipleStructuredOutputsError,
OutputToolBinding,
ProviderStrategy,
ProviderStrategyBinding,
StructuredOutputValidationError,
ToolStrategy,
)
from langchain.agents.tool_node import ToolNode
from langchain.agents.types import (
AgentJump,
AgentMiddleware,
AgentState,
AgentUpdate,
JumpTo,
ModelRequest,
ResponseFormat,
)
STRUCTURED_OUTPUT_ERROR_TEMPLATE = "Error: {error}\n Please fix your mistakes."
def _supports_native_structured_output(model: Union[str, BaseChatModel]) -> bool:
"""Check if a model supports native structured output."""
model_name: str | None = None
if isinstance(model, str):
model_name = model
elif isinstance(model, BaseChatModel):
model_name = getattr(model, "model_name", None)
return (
"grok" in model_name.lower()
or any(part in model_name for part in ["gpt-5", "gpt-4.1", "gpt-oss", "o3-pro", "o3-mini"])
if model_name
else False
)
def _handle_structured_output_error(
exception: Exception,
response_format: ResponseFormat,
) -> tuple[bool, str]:
"""Handle structured output error. Returns (should_retry, retry_tool_message)."""
if not isinstance(response_format, ToolStrategy):
return False, ""
handle_errors = response_format.handle_errors
if handle_errors is False:
return False, ""
if handle_errors is True:
return True, STRUCTURED_OUTPUT_ERROR_TEMPLATE.format(error=str(exception))
if isinstance(handle_errors, str):
return True, handle_errors
if isinstance(handle_errors, type) and issubclass(handle_errors, Exception):
if isinstance(exception, handle_errors):
return True, STRUCTURED_OUTPUT_ERROR_TEMPLATE.format(error=str(exception))
return False, ""
if isinstance(handle_errors, tuple):
if any(isinstance(exception, exc_type) for exc_type in handle_errors):
return True, STRUCTURED_OUTPUT_ERROR_TEMPLATE.format(error=str(exception))
return False, ""
if callable(handle_errors):
return True, handle_errors(exception)
return False, ""
ContextT = TypeVar("ContextT")
ResponseT = TypeVar("ResponseT")
def create_agent(
*,
model: str | BaseChatModel,
tools: Sequence[BaseTool | Callable | dict[str, Any]] | ToolNode | None = None,
system_prompt: str | None = None,
middleware: Sequence[AgentMiddleware] = (),
response_format: ResponseFormat[ResponseT] | type[ResponseT] | None = None,
context_schema: type[ContextT] | None = None,
) -> StateGraph[AgentState[ResponseT], ContextT]:
# init chat model
if isinstance(model, str):
try:
from langchain.chat_models import ( # type: ignore[import-not-found]
init_chat_model,
)
except ImportError:
raise ImportError(
"Please install langchain (`pip install langchain`) to "
"use '<provider>:<model>' string syntax for `model` parameter."
)
model = cast("BaseChatModel", init_chat_model(model))
# Handle tools being None or empty
if tools is None:
tools = []
# Setup structured output
structured_output_tools: dict[str, OutputToolBinding] = {}
native_output_binding: ProviderStrategyBinding | None = None
if response_format is not None:
if not isinstance(response_format, (ToolStrategy, ProviderStrategy)):
# Auto-detect strategy based on model capabilities
if _supports_native_structured_output(model):
response_format = ProviderStrategy(schema=response_format)
else:
response_format = ToolStrategy(schema=response_format)
if isinstance(response_format, ToolStrategy):
# Setup tools strategy for structured output
for response_schema in response_format.schema_specs:
structured_tool_info = OutputToolBinding.from_schema_spec(response_schema)
structured_output_tools[structured_tool_info.tool.name] = structured_tool_info
elif isinstance(response_format, ProviderStrategy):
# Setup native strategy
native_output_binding = ProviderStrategyBinding.from_schema_spec(
response_format.schema_spec
)
middleware_tools = [t for m in middleware for t in m.tools]
# Setup tools
if isinstance(tools, list):
# Extract builtin provider tools (dict format)
builtin_tools = [t for t in tools if isinstance(t, dict)]
regular_tools = [t for t in tools if not isinstance(t, dict)]
# Add structured output tools to regular tools
structured_tools = [info.tool for info in structured_output_tools.values()]
all_tools = middleware_tools + regular_tools + structured_tools
if all_tools: # Only create ToolNode if we have tools
tool_node = ToolNode(tools=all_tools)
else:
tool_node = None
default_tools = regular_tools + builtin_tools + structured_tools + middleware_tools
else:
# tools is ToolNode or None
tool_node = tools
if tool_node:
default_tools = list(tool_node.tools_by_name.values()) + middleware_tools
# Update tool node to know about tools provided by middleware
all_tools = list(tool_node.tools_by_name.values()) + middleware_tools
tool_node = ToolNode(all_tools)
# Add structured output tools
for info in structured_output_tools.values():
default_tools.append(info.tool)
else:
            default_tools = [
                info.tool for info in structured_output_tools.values()
            ] + middleware_tools
# validate middleware
assert len({m.__class__.__name__ for m in middleware}) == len(middleware), (
"Please remove duplicate middleware instances."
)
middleware_w_before = [
m for m in middleware if m.__class__.before_model is not AgentMiddleware.before_model
]
middleware_w_modify_model_request = [
m
for m in middleware
if m.__class__.modify_model_request is not AgentMiddleware.modify_model_request
]
middleware_w_after = [
m for m in middleware if m.__class__.after_model is not AgentMiddleware.after_model
]
# create graph, add nodes
graph = StateGraph(
AgentState,
input_schema=AgentUpdate,
output_schema=AgentUpdate,
context_schema=context_schema,
)
def _prepare_model_request(state: AgentState) -> tuple[ModelRequest, list[BaseMessage]]:
"""Prepare model request and messages."""
request = state.get("model_request") or ModelRequest(
model=model,
tools=default_tools,
system_prompt=system_prompt,
response_format=response_format,
messages=state["messages"],
tool_choice=None,
)
# prepare messages
messages = request.messages
if request.system_prompt:
messages = [SystemMessage(request.system_prompt)] + messages
return request, messages
def _handle_model_output(
state: AgentState, output: AIMessage, request: ModelRequest
) -> AgentState:
"""Handle model output including structured responses."""
# Handle structured output with native strategy
if isinstance(response_format, ProviderStrategy):
if not output.tool_calls and native_output_binding:
structured_response = native_output_binding.parse(output)
return {"messages": output, "response": structured_response}
if state.get("response") is not None:
return {"messages": output, "response": None}
return {"messages": output}
# Handle structured output with tools strategy
if isinstance(response_format, ToolStrategy):
if isinstance(output, AIMessage) and output.tool_calls:
structured_tool_calls = [
tc for tc in output.tool_calls if tc["name"] in structured_output_tools
]
if structured_tool_calls:
if len(structured_tool_calls) > 1:
# Handle multiple structured outputs error
tool_names = [tc["name"] for tc in structured_tool_calls]
exception = MultipleStructuredOutputsError(tool_names)
should_retry, error_message = _handle_structured_output_error(
exception, response_format
)
if not should_retry:
raise exception
# Add error messages and retry
tool_messages = [
ToolMessage(
content=error_message,
tool_call_id=tc["id"],
name=tc["name"],
)
for tc in structured_tool_calls
]
return {"messages": [output] + tool_messages}
# Handle single structured output
tool_call = structured_tool_calls[0]
try:
structured_tool_binding = structured_output_tools[tool_call["name"]]
structured_response = structured_tool_binding.parse(tool_call["args"])
tool_message_content = (
response_format.tool_message_content
if response_format.tool_message_content
else f"Returning structured response: {structured_response}"
)
return {
"messages": [
output,
ToolMessage(
content=tool_message_content,
tool_call_id=tool_call["id"],
name=tool_call["name"],
),
],
"response": structured_response,
}
except Exception as exc:
exception = StructuredOutputValidationError(tool_call["name"], exc)
should_retry, error_message = _handle_structured_output_error(
exception, response_format
)
if not should_retry:
raise exception
return {
"messages": [
output,
ToolMessage(
content=error_message,
tool_call_id=tool_call["id"],
name=tool_call["name"],
),
],
}
# Standard response handling
if state.get("response") is not None:
return {"messages": output, "response": None}
return {"messages": output}
def _get_bound_model(request: ModelRequest) -> BaseChatModel:
"""Get the model with appropriate tool bindings."""
if isinstance(response_format, ProviderStrategy):
# Use native structured output
kwargs = response_format.to_model_kwargs()
return request.model.bind_tools(
request.tools, strict=True, **kwargs, **request.model_settings
)
if isinstance(response_format, ToolStrategy):
tool_choice = "any" if structured_output_tools else request.tool_choice
return request.model.bind_tools(
request.tools, tool_choice=tool_choice, **request.model_settings
)
# Standard model binding
if request.tools:
return request.model.bind_tools(
request.tools, tool_choice=request.tool_choice, **request.model_settings
)
return request.model.bind(**request.model_settings)
def model_request(state: AgentState) -> AgentState:
"""Sync model request handler."""
request, messages = _prepare_model_request(state)
model_ = _get_bound_model(request)
output = model_.invoke(messages)
return _handle_model_output(state, output, request)
async def amodel_request(state: AgentState) -> AgentState:
"""Async model request handler."""
request, messages = _prepare_model_request(state)
model_ = _get_bound_model(request)
output = await model_.ainvoke(messages)
return _handle_model_output(state, output, request)
# Use sync or async based on model capabilities
from langgraph._internal._runnable import RunnableCallable
graph.add_node("model_request", RunnableCallable(model_request, amodel_request))
# Only add tools node if we have tools
if tool_node is not None:
graph.add_node("tools", tool_node)
# Add middleware nodes
for m in middleware:
if m.__class__.before_model is not AgentMiddleware.before_model:
graph.add_node(
f"{m.__class__.__name__}.before_model",
m.before_model,
input_schema=m.state_schema,
)
if m.__class__.modify_model_request is not AgentMiddleware.modify_model_request:
            def modify_model_request_node(
                state: AgentState, *, m: AgentMiddleware = m
            ) -> dict[str, ModelRequest]:
                # Bind the loop variable through a default argument so each node
                # uses its own middleware instance rather than the last one in the loop.
default_model_request = ModelRequest(
model=model,
tools=default_tools,
system_prompt=system_prompt,
response_format=response_format,
messages=state["messages"],
tool_choice=None,
)
return {
"model_request": m.modify_model_request(
state.get("model_request") or default_model_request, state
)
}
graph.add_node(
f"{m.__class__.__name__}.modify_model_request",
modify_model_request_node,
input_schema=m.state_schema,
)
if m.__class__.after_model is not AgentMiddleware.after_model:
graph.add_node(
f"{m.__class__.__name__}.after_model",
m.after_model,
input_schema=m.state_schema,
)
# add start edge
first_node = (
f"{middleware_w_before[0].__class__.__name__}.before_model"
if middleware_w_before
else f"{middleware_w_modify_model_request[0].__class__.__name__}.modify_model_request"
if middleware_w_modify_model_request
else "model_request"
)
last_node = (
f"{middleware_w_after[0].__class__.__name__}.after_model"
if middleware_w_after
else "model_request"
)
graph.add_edge(START, first_node)
# add conditional edges only if tools exist
if tool_node is not None:
graph.add_conditional_edges(
"tools",
_make_tools_to_model_edge(tool_node, first_node, structured_output_tools),
[first_node, END],
)
graph.add_conditional_edges(
last_node,
_make_model_to_tools_edge(first_node, structured_output_tools),
[first_node, "tools", END],
)
else:
if last_node == "model_request":
# If no tools, just go to END from model
graph.add_edge(last_node, END)
else:
# If after_model, then need to check for jump_to
_add_middleware_edge(
graph,
middleware_w_after[0].after_model,
f"{middleware_w_after[0].__class__.__name__}.after_model",
END,
first_node,
tools_available=tool_node is not None,
)
# Add middleware edges (same as before)
if middleware_w_before:
for m1, m2 in zip(middleware_w_before, middleware_w_before[1:], strict=False):
_add_middleware_edge(
graph,
m1.before_model,
f"{m1.__class__.__name__}.before_model",
f"{m2.__class__.__name__}.before_model",
first_node,
tools_available=tool_node is not None,
)
if middleware_w_modify_model_request:
first_modify = middleware_w_modify_model_request[0]
next_node = f"{first_modify.__class__.__name__}.modify_model_request"
else:
next_node = "model_request"
_add_middleware_edge(
graph,
middleware_w_before[-1].before_model,
f"{middleware_w_before[-1].__class__.__name__}.before_model",
next_node,
first_node,
tools_available=tool_node is not None,
)
if middleware_w_modify_model_request:
for m1, m2 in zip(
middleware_w_modify_model_request, middleware_w_modify_model_request[1:], strict=False
):
_add_middleware_edge(
graph,
m1.modify_model_request,
f"{m1.__class__.__name__}.modify_model_request",
f"{m2.__class__.__name__}.modify_model_request",
first_node,
tools_available=tool_node is not None,
)
_add_middleware_edge(
graph,
middleware_w_modify_model_request[-1].modify_model_request,
f"{middleware_w_modify_model_request[-1].__class__.__name__}.modify_model_request",
"model_request",
first_node,
tools_available=tool_node is not None,
)
if middleware_w_after:
graph.add_edge("model_request", f"{middleware_w_after[-1].__class__.__name__}.after_model")
for idx in range(len(middleware_w_after) - 1, 0, -1):
m1 = middleware_w_after[idx]
m2 = middleware_w_after[idx - 1]
_add_middleware_edge(
graph,
m1.after_model,
f"{m1.__class__.__name__}.after_model",
f"{m2.__class__.__name__}.after_model",
first_node,
tools_available=tool_node is not None,
)
return graph
def _resolve_jump(jump_to: JumpTo | None, first_node: str) -> str | None:
if jump_to == "model":
return first_node
if jump_to:
return jump_to
def _make_model_to_tools_edge(
first_node: str, structured_output_tools: dict[str, OutputToolBinding]
) -> Callable[[AgentState], str | None]:
def model_to_tools(state: AgentState) -> str | None:
if jump_to := state.get("jump_to"):
return _resolve_jump(jump_to, first_node)
message = state["messages"][-1]
# Check if this is a ToolMessage from structured output - if so, end
if isinstance(message, ToolMessage) and message.name in structured_output_tools:
return END
# Check for tool calls
if isinstance(message, AIMessage) and message.tool_calls:
# If all tool calls are for structured output, don't go to tools
non_structured_calls = [
tc for tc in message.tool_calls if tc["name"] not in structured_output_tools
]
if non_structured_calls:
return "tools"
return END
return model_to_tools
def _make_tools_to_model_edge(
tool_node: ToolNode, next_node: str, structured_output_tools: dict[str, OutputToolBinding]
) -> Callable[[AgentState], str | None]:
def tools_to_model(state: AgentState) -> str | None:
ai_message = [m for m in state["messages"] if isinstance(m, AIMessage)][-1]
if all(
tool_node.tools_by_name[c["name"]].return_direct
for c in ai_message.tool_calls
if c["name"] in tool_node.tools_by_name
):
return END
return next_node
return tools_to_model
def _add_middleware_edge(
graph: StateGraph,
method: Callable[[AgentState], AgentUpdate | AgentJump | None],
name: str,
default_destination: str,
model_destination: str,
tools_available: bool,
) -> None:
sig = signature(method)
uses_jump = sig.return_annotation is AgentJump or AgentJump in getattr(
sig.return_annotation, "__args__", ()
)
if uses_jump:
def jump_edge(state: AgentState) -> str:
return _resolve_jump(state.get("jump_to"), model_destination) or default_destination
destinations = [default_destination]
if END != default_destination:
destinations.append(END)
if tools_available:
destinations.append("tools")
if name != model_destination:
destinations.append(model_destination)
graph.add_conditional_edges(name, jump_edge, destinations)
else:
graph.add_edge(name, default_destination)

View File

@@ -45,6 +45,7 @@ from langgraph.typing import ContextT, StateT
from pydantic import BaseModel
from typing_extensions import NotRequired, TypedDict, TypeVar
from langchain.agents.middleware_agent import create_agent as create_middleware_agent
from langchain.agents.structured_output import (
MultipleStructuredOutputsError,
OutputToolBinding,
@@ -55,6 +56,7 @@ from langchain.agents.structured_output import (
ToolStrategy,
)
from langchain.agents.tool_node import ToolNode
from langchain.agents.types import AgentMiddleware
from langchain.chat_models import init_chat_model
if TYPE_CHECKING:
@@ -908,6 +910,7 @@ def create_agent( # noqa: D417
],
tools: Union[Sequence[Union[BaseTool, Callable, dict[str, Any]]], ToolNode],
*,
middleware: Sequence[AgentMiddleware] = (),
prompt: Prompt | None = None,
response_format: Union[
ToolStrategy[StructuredResponseT],
@@ -1114,6 +1117,29 @@ def create_agent( # noqa: D417
print(chunk)
```
"""
if middleware:
assert isinstance(model, str | BaseChatModel)
assert isinstance(prompt, str | None)
assert not isinstance(response_format, tuple)
assert pre_model_hook is None
assert post_model_hook is None
assert state_schema is None
return create_middleware_agent( # type: ignore[return-value]
model=model,
tools=tools,
system_prompt=prompt,
middleware=middleware,
response_format=response_format,
context_schema=context_schema,
).compile(
checkpointer=checkpointer,
store=store,
name=name,
interrupt_after=interrupt_after,
interrupt_before=interrupt_before,
debug=debug,
)
# Handle deprecated config_schema parameter
if (config_schema := deprecated_kwargs.pop("config_schema", MISSING)) is not MISSING:
warn(

View File

@@ -0,0 +1,67 @@
from __future__ import annotations
from dataclasses import dataclass, field
from typing import Annotated, Any, ClassVar, Generic, Literal
from langchain_core.language_models.chat_models import BaseChatModel
from langchain_core.messages import AnyMessage
from langchain_core.tools import BaseTool
from langgraph.channels.ephemeral_value import EphemeralValue
from langgraph.graph.message import Messages, add_messages
from typing_extensions import TypedDict, TypeVar
from langchain.agents.structured_output import ResponseFormat
JumpTo = Literal["tools", "model", "__end__"]
ResponseT = TypeVar("ResponseT")
@dataclass
class ModelRequest:
model: BaseChatModel
    system_prompt: str | None
messages: list[AnyMessage] # excluding system prompt
tool_choice: Any
tools: list[BaseTool]
response_format: ResponseFormat | None
model_settings: dict[str, Any] = field(default_factory=dict)
class AgentState(TypedDict, Generic[ResponseT], total=False):
# TODO: import change allowing for required / not required and still registering reducer properly
# do we want to use total = False or require NotRequired?
messages: Annotated[list[AnyMessage], add_messages]
model_request: Annotated[ModelRequest | None, EphemeralValue]
jump_to: Annotated[JumpTo | None, EphemeralValue]
# TODO: structured response maybe?
response: ResponseT
StateT = TypeVar("StateT", bound=AgentState)
class AgentMiddleware(Generic[StateT]):
# TODO: I thought this should be a ClassVar[type[StateT]] but inherently class vars can't use type vars
# bc they're instance dependent
state_schema: ClassVar[type] = AgentState
tools: list[BaseTool] = []
def before_model(self, state: StateT) -> AgentUpdate | AgentJump | None:
pass
def modify_model_request(self, request: ModelRequest, state: StateT) -> ModelRequest:
return request
def after_model(self, state: StateT) -> AgentUpdate | AgentJump | None:
pass
class AgentUpdate(TypedDict, total=False):
messages: Messages
response: dict
class AgentJump(TypedDict, total=False):
messages: Messages
jump_to: JumpTo
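
For orientation, a minimal custom middleware built on the hooks above (illustrative, not part of this diff); each hook is optional and the defaults are no-ops:

class MessageWindowMiddleware(AgentMiddleware):
    """Log before each model call and cap the history sent to the model."""

    def before_model(self, state: AgentState) -> AgentUpdate | AgentJump | None:
        print(f"calling model with {len(state['messages'])} messages")
        return None

    def modify_model_request(self, request: ModelRequest, state: AgentState) -> ModelRequest:
        request.messages = request.messages[-20:]  # keep only the last 20 messages
        return request

    def after_model(self, state: AgentState) -> AgentUpdate | AgentJump | None:
        return None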

View File

@@ -0,0 +1,112 @@
import json
from collections.abc import Callable, Sequence
from dataclasses import asdict, is_dataclass
from typing import (
Any,
Generic,
Literal,
TypeVar,
Union,
)
from langchain_core.callbacks import CallbackManagerForLLMRun
from langchain_core.language_models import BaseChatModel, LanguageModelInput
from langchain_core.messages import (
AIMessage,
BaseMessage,
ToolCall,
)
from langchain_core.outputs import ChatGeneration, ChatResult
from langchain_core.runnables import Runnable
from langchain_core.tools import BaseTool
from pydantic import BaseModel
StructuredResponseT = TypeVar("StructuredResponseT")
class FakeToolCallingModel(BaseChatModel, Generic[StructuredResponseT]):
tool_calls: Union[list[list[ToolCall]], list[list[dict]]] | None = None
structured_response: StructuredResponseT | None = None
index: int = 0
tool_style: Literal["openai", "anthropic"] = "openai"
tools: list = []
def _generate(
self,
messages: list[BaseMessage],
stop: list[str] | None = None,
run_manager: CallbackManagerForLLMRun | None = None,
**kwargs: Any,
) -> ChatResult:
"""Top Level call"""
rf = kwargs.get("response_format")
is_native = isinstance(rf, dict) and rf.get("type") == "json_schema"
if self.tool_calls:
if is_native:
tool_calls = (
self.tool_calls[self.index] if self.index < len(self.tool_calls) else []
)
else:
tool_calls = self.tool_calls[self.index % len(self.tool_calls)]
else:
tool_calls = []
if is_native and not tool_calls:
if isinstance(self.structured_response, BaseModel):
content_obj = self.structured_response.model_dump()
elif is_dataclass(self.structured_response):
content_obj = asdict(self.structured_response)
elif isinstance(self.structured_response, dict):
content_obj = self.structured_response
message = AIMessage(content=json.dumps(content_obj), id=str(self.index))
else:
messages_string = "-".join([m.content for m in messages]) + str(kwargs) + str(self.tools)
message = AIMessage(
content=messages_string,
id=str(self.index),
tool_calls=tool_calls.copy(),
)
self.index += 1
return ChatResult(generations=[ChatGeneration(message=message)])
@property
def _llm_type(self) -> str:
return "fake-tool-call-model"
def bind_tools(
self,
tools: Sequence[Union[dict[str, Any], type[BaseModel], Callable, BaseTool]],
**kwargs: Any,
) -> Runnable[LanguageModelInput, BaseMessage]:
self.tools = tools
if len(tools) == 0:
msg = "Must provide at least one tool"
raise ValueError(msg)
tool_dicts = []
for tool in tools:
if isinstance(tool, dict):
tool_dicts.append(tool)
continue
if not isinstance(tool, BaseTool):
continue
# NOTE: this is a simplified tool spec for testing purposes only
if self.tool_style == "openai":
tool_dicts.append(
{
"type": "function",
"function": {
"name": tool.name,
},
}
)
elif self.tool_style == "anthropic":
tool_dicts.append(
{
"name": tool.name,
}
)
return self.bind(tools=tool_dicts)

View File

@@ -0,0 +1,9 @@
from langchain.agents import create_agent
from langchain.agents.middleware.deepagents import DeepAgentMiddleware
from langchain.chat_models.fake import FakeToolCallingModel
model = FakeToolCallingModel()
agent = create_agent(model, [], middleware=[DeepAgentMiddleware()])
for s in agent.stream({"messages": [{"role": "user", "content": "hi"}]}, stream_mode="debug"):
print(s)

View File

@@ -0,0 +1,24 @@
from langchain_core.callbacks import CallbackManagerForRetrieverRun
from langchain_core.documents import Document
from langchain.agents import create_agent
from langchain.agents.middleware.rag import RAGMiddleware
from langchain.chat_models.fake import FakeToolCallingModel
from langchain_core.retrievers import BaseRetriever
tool_calls = [[{"args": {}, "id": "1", "name": "handoff_to_foo2"}], []]
class FakeRetriever(BaseRetriever):
def _get_relevant_documents(self, query: str, *,
run_manager: CallbackManagerForRetrieverRun) -> list[
Document]:
return [Document(page_content="foo")]
model = FakeToolCallingModel()
middleware = RAGMiddleware.from_retriever(FakeRetriever(), "foo")
agent = create_agent(model, [], middleware=[middleware])
print(agent.get_graph())
for s in agent.stream({"messages": [{"role": "user", "content": "hi"}]}, stream_mode="debug"):
print(s)

View File

@@ -0,0 +1,15 @@
from langchain.agents import create_agent
from langchain.agents.middleware.swarm import SwarmMiddleware
from langchain.chat_models.fake import FakeToolCallingModel
tool_calls = [[{"args": {}, "id": "1", "name": "handoff_to_foo2"}], []]
model = FakeToolCallingModel(tool_calls=tool_calls)
subagents = [
{"name": "foo1", "description": "bar1", "prompt": "hi", "tools": []},
{"name": "foo2", "description": "bar1", "prompt": "bye", "tools": []}
]
middleware = SwarmMiddleware(agents=subagents, starting_agent="foo1")
agent = create_agent(model, [], middleware=[middleware])
print(agent.get_graph())
for s in agent.stream({"messages": [{"role": "user", "content": "hi"}]}, stream_mode="debug"):
print(s)
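As a hedged follow-up (reusing the invoke and pretty_print calls already shown in these scripts), a single invoke makes the scripted handoff visible in the transcript.
# Sketch, not in the original script: the first AI turn should carry the scripted
# handoff_to_foo2 tool call.
result = agent.invoke({"messages": [{"role": "user", "content": "hi"}]})
for msg in result["messages"]:
    msg.pretty_print()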

49
test_summarization.py Normal file
View File

@@ -0,0 +1,49 @@
from langchain.agents.middleware_agent import create_agent
from langchain_core.runnables import RunnableConfig
from langchain_openai import ChatOpenAI
from langgraph.checkpoint.memory import InMemorySaver
from langchain.agents.middleware.summarization import SummarizationMiddleware
agent = create_agent(
model=ChatOpenAI(model="gpt-4.1-mini"),
system_prompt="You are a helpful assistant. Please reply nicely.",
middleware=[
SummarizationMiddleware(
model=ChatOpenAI(model="gpt-4.1-mini"), messages_to_keep=3
)
],
)
agent = agent.compile(checkpointer=InMemorySaver())
config: RunnableConfig = {"configurable": {"thread_id": "long_convo"}}
agent.invoke({"messages": "hi, my name is bob"}, config)
agent.invoke({"messages": "my favorite food is pizza"}, config)
agent.invoke({"messages": "my favorite color is blue"}, config)
agent.invoke({"messages": "my favorite animal is a dog"}, config)
final_response = agent.invoke({"messages": "what's my name?"}, config)
for msg in final_response["messages"]:
msg.pretty_print()
"""
================================ System Message ================================
## Previous conversation summary:
User name: Bob. User's favorite food is pizza. User's favorite color is blue.
================================ Human Message =================================
my favorite animal is a dog
================================== Ai Message ==================================
Dogs are wonderful companions, Bob! Do you have a favorite breed, or maybe a dog of your own?
================================ Human Message =================================
what's my name?
================================== Ai Message ==================================
Your name is Bob! How can I assist you today, Bob?
"""

31
testing_caching.py Normal file
View File

@@ -0,0 +1,31 @@
from langchain_anthropic import ChatAnthropic
from langchain.agents.middleware.prompt_caching import AnthropicPromptCachingMiddleware
from langchain.agents import create_agent
from langchain_core.messages import HumanMessage, AIMessage
from langgraph.checkpoint.memory import InMemorySaver
LONG_PROMPT = """
Please be a helpful assistant.
""" + "a" * (100 * 60) # 100 chars per line * 60 lines
agent = create_agent(
model=ChatAnthropic(model="claude-sonnet-4-20250514"),
tools=[],
prompt=LONG_PROMPT,
middleware=[AnthropicPromptCachingMiddleware(type="ephemeral", ttl="5m", min_messages_to_cache=3)],
checkpointer=InMemorySaver(),
)
config = {"configurable": {"thread_id": "abc"}}
agent.invoke({"messages": [HumanMessage("Hello")]}, config)
agent.invoke({"messages": [HumanMessage("Hello")]}, config)
result3 = agent.invoke({"messages": [HumanMessage("Hello")]}, config)
for msg in result3["messages"]:
msg.pretty_print()
if isinstance(msg, AIMessage):
print(f"usage: {msg.response_metadata['usage']}")

163
testing_middleware.py Normal file
View File

@@ -0,0 +1,163 @@
from langchain.agents.middleware_agent import create_agent
from langchain_openai import ChatOpenAI
from langchain_core.messages import HumanMessage
from langchain_core.tools import tool
import operator
from dataclasses import dataclass
from typing import Annotated
from pydantic import BaseModel
from langchain.agents.structured_output import ToolStrategy
from langchain.agents.middleware.model_call_limits import ModelRequestLimitMiddleware
@tool
def get_weather(city: str) -> str:
"""Get the weather for a given city"""
return f"The weather in {city} is sunny."
class WeatherResponse(BaseModel):
city: str
weather: str
# model request limit middleware (note: only 3 tool calls are made below)
agent = create_agent(
model=ChatOpenAI(model="gpt-4o"),
tools=[get_weather],
system_prompt="You are a weather agent. You are tasked with fetching the weather for a given city.",
middleware=[ModelRequestLimitMiddleware(max_requests=3)],
)
agent = agent.compile()
result = agent.invoke(
{
"messages": [
HumanMessage(content="Please check the weather in SF, NYC, LA, and Boston.")
]
}
)
for msg in result["messages"]:
msg.pretty_print()
"""
================================ Human Message =================================
Please check the weather in SF, NYC, LA, and Boston.
================================== Ai Message ==================================
Tool Calls:
get_weather (call_7LddqyVgqxjTYm84UUfFBFZA)
Call ID: call_7LddqyVgqxjTYm84UUfFBFZA
Args:
city: San Francisco
================================= Tool Message =================================
Name: get_weather
The weather in San Francisco is sunny.
================================== Ai Message ==================================
Tool Calls:
get_weather (call_gUL7CHn6YqE80M9M5G5miA3k)
Call ID: call_gUL7CHn6YqE80M9M5G5miA3k
Args:
city: New York City
================================= Tool Message =================================
Name: get_weather
The weather in New York City is sunny.
================================== Ai Message ==================================
Tool Calls:
get_weather (call_asOAXRkPbBWBdt4SzQGPYQab)
Call ID: call_asOAXRkPbBWBdt4SzQGPYQab
Args:
city: Los Angeles
================================= Tool Message =================================
Name: get_weather
The weather in Los Angeles is sunny.
"""
# structured response
agent = create_agent(
model=ChatOpenAI(model="gpt-4o"),
tools=[get_weather],
system_prompt="You are a weather agent. You are tasked with fetching the weather for a given city.",
middleware=[ModelRequestLimitMiddleware(max_requests=3)],
response_format=ToolStrategy(WeatherResponse),
)
agent = agent.compile()
result = agent.invoke(
{
"messages": [
HumanMessage(content="Please check the weather in SF")
]
}
)
print(repr(result["response"]))
#> WeatherResponse(city='SF', weather='sunny')
# builtin provider tool support (web search for ex)
agent = create_agent(
model=ChatOpenAI(model="gpt-4o"),
tools=[{"type": "web_search_preview"}],
system_prompt="You are a weather agent. You are tasked with fetching the weather for a given city. Please use the web search tool to fetch the weather.",
# response_format=WeatherResponse,
)
agent = agent.compile()
result = agent.invoke(
{
"messages": [
HumanMessage(content="What is the weather in SF?")
]
}
)
for msg in result["messages"]:
msg.pretty_print()
"""
================================ Human Message =================================
What is the weather in SF?
================================== Ai Message ==================================
[{'type': 'text', 'text': 'As of 1:58 PM PDT on Friday, September 5, 2025, the weather in San Francisco, CA, is mostly cloudy with a temperature of 66°F (19°C). ([weather.com](https://weather.com/weather/today/l/San%2BFrancisco%2BCA?canonicalCityId=e7784799733d2133bcb75674a102b347&utm_source=openai))\n\n## Weather for San Francisco, CA:\nCurrent Conditions: Cloudy, 58°F (14°C)\n\nDaily Forecast:\n* Friday, September 5: Low: 60°F (15°C), High: 69°F (20°C), Description: Low clouds breaking for some sun\n* Saturday, September 6: Low: 61°F (16°C), High: 69°F (21°C), Description: Areas of low clouds, then sun and pleasant\n* Sunday, September 7: Low: 63°F (17°C), High: 72°F (22°C), Description: Areas of low clouds, then sun and pleasant\n* Monday, September 8: Low: 63°F (17°C), High: 71°F (21°C), Description: Low clouds breaking for some sun\n* Tuesday, September 9: Low: 60°F (16°C), High: 70°F (21°C), Description: Morning low clouds followed by clouds giving way to some sun\n* Wednesday, September 10: Low: 56°F (13°C), High: 68°F (20°C), Description: Mostly cloudy with a shower in places\n* Thursday, September 11: Low: 56°F (13°C), High: 69°F (21°C), Description: Partly sunny\n ', 'annotations': [{'end_index': 274, 'start_index': 134, 'title': 'Weather Forecast and Conditions for San Francisco, CA - The Weather Channel | Weather.com', 'type': 'url_citation', 'url': 'https://weather.com/weather/today/l/San%2BFrancisco%2BCA?canonicalCityId=e7784799733d2133bcb75674a102b347&utm_source=openai'}]}]
"""
# system prompt and tools as None
agent = create_agent(
model=ChatOpenAI(model="gpt-4o"),
tools=None,
system_prompt=None,
middleware=[ModelRequestLimitMiddleware(max_requests=3)],
)
agent = agent.compile()
result = agent.invoke(
{
"messages": [
HumanMessage(content="What is 2 + 2?")
]
}
)
result["messages"][-1].pretty_print()
"""
================================== Ai Message ==================================
2 + 2 equals 4.
"""
# a call and call model