mirror of
https://github.com/hwchase17/langchain.git
synced 2026-01-30 13:50:11 +00:00
feat(langchain): support for parallel (or interrupted) tool calls and structured output (#32980)
This enables parallel tool calling w/ a combo of
1. Standard and structured response tool calls
2. Deferred (requiring human approval / edits) tool calls and structured
response tool calls
Hard to unit test w/ HITL right now end to end, so here's a repro of
things working w/ an integration test:
```py
from pydantic import BaseModel, Field
from langchain_core.tools import tool
from langchain_core.messages import HumanMessage
from langgraph.types import Command
from langgraph.checkpoint.memory import InMemorySaver
from langchain.agents.middleware_agent import create_agent
from langchain.agents.middleware.human_in_the_loop import HumanInTheLoopMiddleware
from langchain_openai import ChatOpenAI
class WeatherBaseModel(BaseModel):
temperature: float = Field(description="Temperature in fahrenheit")
condition: str = Field(description="Weather condition")
@tool
def add_numbers(a: int, b: int) -> int:
"""Add two numbers."""
return a + b
model = ChatOpenAI(model="gpt-4o-mini", temperature=0)
checkpointer = InMemorySaver()
agent = create_agent(
model=model,
tools=[add_numbers],
response_format=WeatherBaseModel,
middleware=[HumanInTheLoopMiddleware(tool_configs={"add_numbers": True})],
)
agent = agent.compile(checkpointer=checkpointer)
# First invocation should be interrupted due to human-in-the-loop middleware
response = agent.invoke(
{
"messages": [
HumanMessage(
"Add 1 and 2, then return the weather forecast with temperature 72 and condition sunny."
)
]
},
config={"configurable": {"thread_id": "1"}},
)
interrupt_description = response["__interrupt__"][0].value[0]["description"]
print(interrupt_description)
"""
Tool execution requires approval
Tool: add_numbers
Args: {'a': 1, 'b': 2}
"""
# Resume the agent with approval
response = agent.invoke(
Command(resume=[{"type": "approve"}]), config={"configurable": {"thread_id": "1"}}
)
for msg in response["messages"]:
msg.pretty_print()
"""
================================ Human Message =================================
Add 1 and 2, then return the weather forecast with temperature 72 and condition sunny.
================================== Ai Message ==================================
Tool Calls:
WeatherBaseModel (call_u6nXsEYRJbqNx4AEHgiQMpE2)
Call ID: call_u6nXsEYRJbqNx4AEHgiQMpE2
Args:
temperature: 72
condition: sunny
add_numbers (call_nuQEZF7PwfYDlVpnSt8eaInI)
Call ID: call_nuQEZF7PwfYDlVpnSt8eaInI
Args:
a: 1
b: 2
================================= Tool Message =================================
Name: WeatherBaseModel
Returning structured response: temperature=72.0 condition='sunny'
================================= Tool Message =================================
Name: add_numbers
3
"""
print(repr(response["response"]))
"""
WeatherBaseModel(temperature=72.0, condition='sunny')
"""
```
This commit is contained in:
@@ -138,16 +138,15 @@ class HumanInTheLoopMiddleware(AgentMiddleware):
|
||||
if not messages:
|
||||
return None
|
||||
|
||||
last_message = messages[-1]
|
||||
|
||||
if not isinstance(last_message, AIMessage) or not last_message.tool_calls:
|
||||
last_ai_msg = next((msg for msg in messages if isinstance(msg, AIMessage)), None)
|
||||
if not last_ai_msg or not last_ai_msg.tool_calls:
|
||||
return None
|
||||
|
||||
# Separate tool calls that need interrupts from those that don't
|
||||
interrupt_tool_calls: dict[str, ToolCall] = {}
|
||||
auto_approved_tool_calls = []
|
||||
|
||||
for tool_call in last_message.tool_calls:
|
||||
for tool_call in last_ai_msg.tool_calls:
|
||||
tool_name = tool_call["name"]
|
||||
if tool_name in self.tool_configs:
|
||||
# fix: id should not be typed as Optional on `langchain_core.messages.tool.ToolCall`
|
||||
@@ -259,9 +258,9 @@ class HumanInTheLoopMiddleware(AgentMiddleware):
|
||||
)
|
||||
raise ValueError(msg)
|
||||
|
||||
last_message.tool_calls = [*approved_tool_calls, *rejected_tool_calls]
|
||||
last_ai_msg.tool_calls = [*approved_tool_calls, *rejected_tool_calls]
|
||||
|
||||
if len(approved_tool_calls) > 0:
|
||||
return {"messages": [last_message, *artificial_tool_messages]}
|
||||
return {"messages": [last_ai_msg, *artificial_tool_messages]}
|
||||
|
||||
return {"jump_to": "model", "messages": artificial_tool_messages}
|
||||
|
||||
@@ -5,7 +5,7 @@ from collections.abc import Callable, Sequence
|
||||
from typing import Any, cast
|
||||
|
||||
from langchain_core.language_models.chat_models import BaseChatModel
|
||||
from langchain_core.messages import AIMessage, SystemMessage, ToolMessage
|
||||
from langchain_core.messages import AIMessage, AnyMessage, SystemMessage, ToolMessage
|
||||
from langchain_core.runnables import Runnable
|
||||
from langchain_core.tools import BaseTool
|
||||
from langgraph.constants import END, START
|
||||
@@ -219,8 +219,6 @@ def create_agent( # noqa: PLR0915
|
||||
if not output.tool_calls and native_output_binding:
|
||||
structured_response = native_output_binding.parse(output)
|
||||
return {"messages": [output], "response": structured_response}
|
||||
if state.get("response") is not None:
|
||||
return {"messages": [output], "response": None}
|
||||
return {"messages": [output]}
|
||||
|
||||
# Handle structured output with tools strategy
|
||||
@@ -418,7 +416,7 @@ def create_agent( # noqa: PLR0915
|
||||
if tool_node is not None:
|
||||
graph.add_conditional_edges(
|
||||
"tools",
|
||||
_make_tools_to_model_edge(tool_node, first_node),
|
||||
_make_tools_to_model_edge(tool_node, first_node, structured_output_tools),
|
||||
[first_node, END],
|
||||
)
|
||||
graph.add_conditional_edges(
|
||||
@@ -482,6 +480,22 @@ def _resolve_jump(jump_to: JumpTo | None, first_node: str) -> str | None:
|
||||
return None
|
||||
|
||||
|
||||
def _fetch_last_ai_and_tool_messages(
|
||||
messages: list[AnyMessage],
|
||||
) -> tuple[AIMessage, list[ToolMessage]]:
|
||||
last_ai_index: int
|
||||
last_ai_message: AIMessage
|
||||
|
||||
for i in range(len(messages) - 1, -1, -1):
|
||||
if isinstance(messages[i], AIMessage):
|
||||
last_ai_index = i
|
||||
last_ai_message = cast("AIMessage", messages[i])
|
||||
break
|
||||
|
||||
tool_messages = [m for m in messages[last_ai_index + 1 :] if isinstance(m, ToolMessage)]
|
||||
return last_ai_message, tool_messages
|
||||
|
||||
|
||||
def _make_model_to_tools_edge(
|
||||
first_node: str, structured_output_tools: dict[str, OutputToolBinding], tool_node: ToolNode
|
||||
) -> Callable[[AgentState], str | list[Send] | None]:
|
||||
@@ -489,31 +503,7 @@ def _make_model_to_tools_edge(
|
||||
if jump_to := state.get("jump_to"):
|
||||
return _resolve_jump(jump_to, first_node)
|
||||
|
||||
last_message = state["messages"][-1]
|
||||
|
||||
# Check if this is a ToolMessage from structured output - if so, end
|
||||
# interesting, should we be auto ending here? should we execute other tools?
|
||||
if isinstance(last_message, ToolMessage) and last_message.name in structured_output_tools:
|
||||
return END
|
||||
|
||||
# Find the last AI message and all tool messages since said AI message
|
||||
last_ai_index = None
|
||||
last_ai_message: AIMessage
|
||||
for i in range(len(state["messages"]) - 1, -1, -1):
|
||||
if isinstance(state["messages"][i], AIMessage):
|
||||
last_ai_index = i
|
||||
last_ai_message = cast("AIMessage", state["messages"][i])
|
||||
break
|
||||
|
||||
tool_messages = (
|
||||
[
|
||||
m.tool_call_id
|
||||
for m in state["messages"][last_ai_index + 1 :]
|
||||
if isinstance(m, ToolMessage)
|
||||
]
|
||||
if last_ai_index is not None
|
||||
else []
|
||||
)
|
||||
last_ai_message, tool_messages = _fetch_last_ai_and_tool_messages(state["messages"])
|
||||
|
||||
pending_tool_calls = [
|
||||
c
|
||||
@@ -538,17 +528,21 @@ def _make_model_to_tools_edge(
|
||||
|
||||
|
||||
def _make_tools_to_model_edge(
|
||||
tool_node: ToolNode, next_node: str
|
||||
tool_node: ToolNode, next_node: str, structured_output_tools: dict[str, OutputToolBinding]
|
||||
) -> Callable[[AgentState], str | None]:
|
||||
def tools_to_model(state: AgentState) -> str | None:
|
||||
ai_message = [m for m in state["messages"] if isinstance(m, AIMessage)][-1]
|
||||
last_ai_message, tool_messages = _fetch_last_ai_and_tool_messages(state["messages"])
|
||||
|
||||
if all(
|
||||
tool_node.tools_by_name[c["name"]].return_direct
|
||||
for c in ai_message.tool_calls
|
||||
for c in last_ai_message.tool_calls
|
||||
if c["name"] in tool_node.tools_by_name
|
||||
):
|
||||
return END
|
||||
|
||||
if any(t.name in structured_output_tools for t in tool_messages):
|
||||
return END
|
||||
|
||||
return next_node
|
||||
|
||||
return tools_to_model
|
||||
|
||||
@@ -4,6 +4,7 @@ from unittest.mock import patch
|
||||
|
||||
from syrupy.assertion import SnapshotAssertion
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
from langchain_core.language_models import BaseChatModel
|
||||
from langchain_core.language_models.chat_models import BaseChatModel
|
||||
from langchain_core.messages import (
|
||||
@@ -14,6 +15,7 @@ from langchain_core.messages import (
|
||||
ToolMessage,
|
||||
)
|
||||
from langchain_core.tools import tool
|
||||
from langgraph.types import Command
|
||||
|
||||
from langchain.agents.middleware_agent import create_agent
|
||||
from langchain.agents.middleware.human_in_the_loop import (
|
||||
@@ -28,7 +30,7 @@ from langgraph.checkpoint.base import BaseCheckpointSaver
|
||||
from langgraph.checkpoint.memory import InMemorySaver
|
||||
from langgraph.constants import END
|
||||
from langgraph.graph.message import REMOVE_ALL_MESSAGES
|
||||
from langgraph.prebuilt.interrupt import ActionRequest
|
||||
from langchain.agents.structured_output import ToolStrategy
|
||||
|
||||
from .messages import _AnyIdHumanMessage, _AnyIdToolMessage
|
||||
from .model import FakeToolCallingModel
|
||||
@@ -1275,3 +1277,78 @@ def test_modify_model_request() -> None:
|
||||
assert (
|
||||
result["messages"][2].content == "You are a helpful assistant.-Hello-remember to be nice!"
|
||||
)
|
||||
|
||||
|
||||
def test_tools_to_model_edge_with_structured_and_regular_tool_calls():
|
||||
"""Test that when there are both structured and regular tool calls, we execute regular and jump to END."""
|
||||
|
||||
class WeatherResponse(BaseModel):
|
||||
"""Weather response."""
|
||||
|
||||
temperature: float = Field(description="Temperature in fahrenheit")
|
||||
condition: str = Field(description="Weather condition")
|
||||
|
||||
@tool
|
||||
def regular_tool(query: str) -> str:
|
||||
"""A regular tool that returns a string."""
|
||||
return f"Regular tool result for: {query}"
|
||||
|
||||
# Create a fake model that returns both structured and regular tool calls
|
||||
class FakeModelWithBothToolCalls(FakeToolCallingModel):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.tool_calls = [
|
||||
[
|
||||
ToolCall(
|
||||
name="WeatherResponse",
|
||||
args={"temperature": 72.0, "condition": "sunny"},
|
||||
id="structured_call_1",
|
||||
),
|
||||
ToolCall(
|
||||
name="regular_tool", args={"query": "test query"}, id="regular_call_1"
|
||||
),
|
||||
]
|
||||
]
|
||||
|
||||
# Create agent with both structured output and regular tools
|
||||
agent = create_agent(
|
||||
model=FakeModelWithBothToolCalls(),
|
||||
tools=[regular_tool],
|
||||
response_format=ToolStrategy(schema=WeatherResponse),
|
||||
)
|
||||
|
||||
# Compile and invoke the agent
|
||||
compiled_agent = agent.compile()
|
||||
result = compiled_agent.invoke(
|
||||
{"messages": [HumanMessage("What's the weather and help me with a query?")]}
|
||||
)
|
||||
|
||||
# Verify that we have the expected messages:
|
||||
# 1. Human message
|
||||
# 2. AI message with both tool calls
|
||||
# 3. Tool message from structured tool call
|
||||
# 4. Tool message from regular tool call
|
||||
|
||||
messages = result["messages"]
|
||||
assert len(messages) >= 4
|
||||
|
||||
# Check that we have the AI message with both tool calls
|
||||
ai_message = messages[1]
|
||||
assert isinstance(ai_message, AIMessage)
|
||||
assert len(ai_message.tool_calls) == 2
|
||||
|
||||
# Check that we have a tool message from the regular tool
|
||||
tool_messages = [m for m in messages if isinstance(m, ToolMessage)]
|
||||
assert len(tool_messages) >= 1
|
||||
|
||||
# The regular tool should have been executed
|
||||
regular_tool_message = next((m for m in tool_messages if m.name == "regular_tool"), None)
|
||||
assert regular_tool_message is not None
|
||||
assert "Regular tool result for: test query" in regular_tool_message.content
|
||||
|
||||
# Verify that the structured response is available in the result
|
||||
assert "response" in result
|
||||
assert result["response"] is not None
|
||||
assert hasattr(result["response"], "temperature")
|
||||
assert result["response"].temperature == 72.0
|
||||
assert result["response"].condition == "sunny"
|
||||
|
||||
Reference in New Issue
Block a user