mirror of https://github.com/hwchase17/langchain.git
synced 2025-09-17 15:35:14 +00:00

new middleware

libs/langchain_v1/langchain/agents/middleware/_utils.py (new file, 9 lines)
@@ -0,0 +1,9 @@
def _generate_correction_tool_messages(content: str, tool_calls: list):
    tool_messages = []
    for tool_call in tool_calls:
        tool_messages.append({
            "role": "tool",
            "content": content,
            "tool_call_id": tool_call["id"],
        })
    return tool_messages
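
For illustration (not part of the commit), the helper fans one correction message out per tool call, so every pending call gets answered:

# Made-up tool calls; only the "id" field matters to the helper.
calls = [{"name": "search", "args": {}, "id": "1"},
         {"name": "write_file", "args": {}, "id": "2"}]
msgs = _generate_correction_tool_messages("Call only one tool at a time.", calls)
# msgs[0]["tool_call_id"] == "1"; msgs[1]["tool_call_id"] == "2";
# both carry the same corrective content.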

libs/langchain_v1/langchain/agents/middleware/deepagents.py (new file, 68 lines)
@@ -0,0 +1,68 @@
from typing import Annotated, Literal, NotRequired

from langchain_core.messages import ToolMessage
from langchain_core.tools import InjectedToolCallId
from langgraph.prebuilt import InjectedState
from langgraph.types import Command
from typing_extensions import TypedDict

from langchain.agents.types import AgentMiddleware, AgentState, ModelRequest


class Todo(TypedDict):
    """Todo to track."""

    content: str
    status: Literal["pending", "in_progress", "completed"]


def file_reducer(l, r):
    """Merge two file mappings; the right-hand side wins on key conflicts."""
    if l is None:
        return r
    elif r is None:
        return l
    else:
        return {**l, **r}


class DeepAgentState(AgentState):
    todos: NotRequired[list[Todo]]
    files: Annotated[NotRequired[dict[str, str]], file_reducer]


def write_todos(
    todos: list[Todo], tool_call_id: Annotated[str, InjectedToolCallId]
) -> Command:
    """Write the todo list to state."""
    return Command(
        update={
            "todos": todos,
            "messages": [
                ToolMessage(f"Updated todo list to {todos}", tool_call_id=tool_call_id)
            ],
        }
    )


def ls(state: Annotated[DeepAgentState, InjectedState]) -> list[str]:
    """List all files."""
    return list(state.get("files", {}).keys())


class DeepAgentMiddleware(AgentMiddleware):

    state_schema = DeepAgentState

    def __init__(self, subagents: list | None = None):
        # A None default avoids sharing one mutable list across instances.
        self.subagents = subagents or []

    @property
    def tools(self):
        return [write_todos, ls] + self.subagents

    def modify_model_request(self, request: ModelRequest, state: DeepAgentState) -> ModelRequest:
        if request.system_prompt:
            request.system_prompt += "\n\nUse the todo tool to plan as needed"
        else:
            request.system_prompt = "Use the todo tool to plan as needed"
        return request
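
For illustration (not in the commit), file_reducer gives the files channel per-key, last-write-wins merge semantics:

assert file_reducer(None, {"a.txt": "v1"}) == {"a.txt": "v1"}
assert file_reducer({"a.txt": "v1"}, {"a.txt": "v2", "b.txt": "x"}) == {"a.txt": "v2", "b.txt": "x"}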

@@ -0,0 +1,22 @@
from langchain.agents.types import AgentMiddleware, AgentState, ModelRequest


class DynamicPrompt(AgentMiddleware):

    def __init__(self, modifier):
        self.modifier = modifier

    def modify_model_request(self, request: ModelRequest, state) -> ModelRequest:
        prompt = self.modifier(state)
        request.system_prompt = prompt
        return request


class DynamicMessages(AgentMiddleware):

    def __init__(self, modifier):
        self.modifier = modifier

    def modify_model_request(self, request: ModelRequest, state) -> ModelRequest:
        messages = self.modifier(state)
        request.messages = messages
        return request
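
A sketch of the intended use (assumed, not from the commit): a modifier is any callable from agent state to a prompt string (DynamicPrompt) or a message list (DynamicMessages). For example, in the same module:

def prompt_by_turn(state):
    # Hypothetical modifier: tighten the prompt as the conversation grows.
    return "Be brief." if len(state["messages"]) > 10 else "Be thorough."

middleware = DynamicPrompt(prompt_by_turn)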

@@ -7,6 +7,7 @@ from langgraph.prebuilt.interrupt import (
 from langgraph.types import interrupt
 
 from langchain.agents.types import AgentJump, AgentMiddleware, AgentState, AgentUpdate
+from langchain.agents.middleware._utils import _generate_correction_tool_messages
 
 ToolInterruptConfig = dict[str, HumanInterruptConfig]
 
@@ -50,11 +51,22 @@ class HumanInTheLoopMiddleware(AgentMiddleware):
 
         # Right now, we do not support multiple tool calls with interrupts
         if len(interrupt_tool_calls) > 1:
-            raise ValueError("Does not currently support multiple tool calls with interrupts")
+            tool_names = [t['name'] for t in interrupt_tool_calls]
+            msg = f"Called the following tools which require interrupts: {tool_names}\n\nYou may only call ONE tool that requires an interrupt at a time"
+            return {
+                "messages": _generate_correction_tool_messages(msg, last_message.tool_calls),
+                "jump_to": "model"
+            }
 
         # Right now, we do not support interrupting a tool call if other tool calls exist
         if auto_approved_tool_calls:
-            raise ValueError("Does not currently support interrupting a tool call if other tool calls exist")
+            tool_names = [t['name'] for t in interrupt_tool_calls]
+            msg = f"Called the following tools which require interrupts: {tool_names}. You also called other tools that do not require interrupts. If you call a tool that requires an interrupt, you may ONLY call that tool."
+            return {
+                "messages": _generate_correction_tool_messages(msg, last_message.tool_calls),
+                "jump_to": "model"
+            }
 
         # Only one tool call will need interrupts
         tool_call = interrupt_tool_calls[0]
@@ -87,14 +99,14 @@ class HumanInTheLoopMiddleware(AgentMiddleware):
             }
             approved_tool_calls.append(new_tool_call)
         elif response["type"] == "ignore":
-            return {"goto": "__end__"}
+            return {"jump_to": "__end__"}
         elif response["type"] == "response":
             tool_message = {
                 "role": "tool",
                 "tool_call_id": tool_call["id"],
                 "content": response["args"],
             }
-            return {"messages": [tool_message], "goto": "model"}
+            return {"messages": [tool_message], "jump_to": "model"}
         else:
             raise ValueError(f"Unknown response type: {response['type']}")
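
For context, a hedged sketch of the config this middleware consumes: ToolInterruptConfig maps tool names to langgraph's HumanInterruptConfig; the field names below are assumptions about that TypedDict.

# Interrupt only on the risky tool (illustrative config).
tool_configs = {
    "delete_records": {
        "allow_ignore": True,
        "allow_respond": True,
        "allow_edit": True,
        "allow_accept": True,
    },
}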

@@ -0,0 +1,145 @@
from langchain.agents.types import AgentMiddleware, AgentState, ModelRequest, AgentUpdate, AgentJump
from typing_extensions import TypedDict, Type
from langchain.tools import tool
from langchain.chat_models import init_chat_model
from langchain.agents.middleware._utils import _generate_correction_tool_messages

_HANDBACK_NAME = "hand_back"


class Agent(TypedDict):
    name: str
    description: str
    prompt: str
    tools: list
    model: str
    model_settings: dict
    response_format: Type


class SwarmAgentState(AgentState):
    active_agent: str | None


class SwarmMiddleware(AgentMiddleware):

    state_schema = SwarmAgentState

    def __init__(self, agents: list[Agent], starting_agent: str):
        self.agents = agents
        self.starting_agent = starting_agent
        self.agent_mapping = {a['name']: a for a in agents}

    @property
    def tools(self):
        return [t for a in self.agents for t in a['tools']]

    def _get_handoff_tool(self, agent: Agent):
        @tool(
            name_or_callable=f"handoff_to_{agent['name']}",
            description=f"Handoff to agent {agent['name']}. Description of this agent:\n\n{agent['description']}"
        )
        def handoff():
            pass

        return handoff

    def _get_pass_back_tool(self):
        @tool(
            name_or_callable=_HANDBACK_NAME,
            description="Call this if you are unable to handle the current request. You will hand back control of the conversation to your supervisor"
        )
        def hand_back():
            pass

        return hand_back

    def _get_main_handoff_tools(self):
        tools = []
        for agent in self.agents:
            tools.append(self._get_handoff_tool(agent))
        return tools

    def modify_model_request(self, request: ModelRequest, state: SwarmAgentState) -> ModelRequest:
        if state.get('active_agent') is None:
            request.tools = request.tools + self._get_main_handoff_tools()
            return request
        active_agent = self.agent_mapping[state['active_agent']]
        request.system_prompt = active_agent['prompt']
        # The active agent gets its own tools plus a way to hand control back.
        request.tools = active_agent['tools'] + [self._get_pass_back_tool()]
        if 'model' in active_agent:
            request.model = init_chat_model(active_agent['model'])
        if 'model_settings' in active_agent:
            request.model_settings = active_agent['model_settings']
        if 'response_format' in active_agent:
            request.response_format = active_agent['response_format']
        return request

    def after_model(self, state: SwarmAgentState) -> AgentUpdate | AgentJump | None:
        messages = state["messages"]
        active_agent = state.get('active_agent')
        if not messages:
            return None

        last_message = messages[-1]

        if not hasattr(last_message, "tool_calls") or not last_message.tool_calls:
            return None
        if active_agent is not None:
            # A subagent is active: look for a hand-back call.
            handoffs = []
            for tool_call in last_message.tool_calls:
                if tool_call['name'] == _HANDBACK_NAME:
                    handoffs.append(tool_call)
            if len(handoffs) == 0:
                return None
            elif len(handoffs) > 1:
                msg = "Multiple handoffs at the same time are not supported, please just call one at a time."
                return {
                    "messages": _generate_correction_tool_messages(msg, last_message.tool_calls),
                    "jump_to": "model"
                }
            else:
                tool_call = handoffs[0]
                return {
                    "messages": [{
                        "role": "tool",
                        "tool_call_id": tool_call["id"],
                        "content": "Handed back to supervisor",
                    }],
                    "active_agent": None,
                    "jump_to": "model"
                }
        # The supervisor is active: look for handoff calls to subagents.
        handoff_tools = self._get_main_handoff_tools()
        handoff_tool_names = [t.name for t in handoff_tools]

        handoffs = []
        for tool_call in last_message.tool_calls:
            if tool_call['name'] in handoff_tool_names:
                handoffs.append(tool_call)
        if len(handoffs) == 0:
            return None
        elif len(handoffs) > 1:
            msg = "Multiple handoffs at the same time are not supported, please just call one at a time."
            return {
                "messages": _generate_correction_tool_messages(msg, last_message.tool_calls),
                "jump_to": "model"
            }
        else:
            tool_call = handoffs[0]
            handoff_to = tool_call['name'][len("handoff_to_"):]
            return {
                "messages": [{
                    "role": "tool",
                    "tool_call_id": tool_call["id"],
                    "content": f"Handed off to agent {handoff_to}",
                }],
                "active_agent": handoff_to,
                "jump_to": "model"
            }
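
A usage sketch under assumptions (the agent spec is invented, and this file's import path is not shown in this view):

billing = {
    "name": "billing",
    "description": "Handles billing questions",
    "prompt": "You handle billing.",
    "tools": [],
}
middleware = SwarmMiddleware(agents=[billing], starting_agent="billing")
# While active_agent is None the supervisor model sees handoff_to_billing;
# once billing is active it sees its own tools plus hand_back.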

libs/langchain_v1/langchain/agents/middleware/rag.py (new file, 35 lines)
@@ -0,0 +1,35 @@
from langchain.agents.types import AgentMiddleware, AgentState, ModelRequest, AgentJump, AgentUpdate
from langchain_core.retrievers import BaseRetriever
from langchain_core.tools import tool
import uuid


class RAGMiddleware(AgentMiddleware):

    @classmethod
    def from_retriever(cls, retriever: BaseRetriever, description: str):
        @tool(description=description)
        def retrieve(query: str):
            return retriever.get_relevant_documents(query)

        return cls(retrieve)

    def __init__(self, tool):
        self.tool = tool

    @property
    def tools(self):
        return [self.tool]

    def before_model(self, state: AgentState) -> AgentUpdate | AgentJump | None:
        # On the first user turn, force a retrieval call before the model runs.
        if len(state['messages']) == 1:
            forced_tool_call = {
                "type": "tool_call",
                "name": self.tool.name,
                "args": {"query": state['messages'][0].content},
                "id": str(uuid.uuid4()),
            }
            return {
                "messages": [{"role": "ai", "content": None, "tool_calls": [forced_tool_call]}],
                "jump_to": "tools"
            }
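
An assumed usage sketch (not part of the commit): any BaseRetriever instance works, and the description becomes the retrieve tool's description.

# docs_retriever is hypothetical; substitute any BaseRetriever.
middleware = RAGMiddleware.from_retriever(
    docs_retriever,
    description="Search the product docs for passages relevant to the query",
)
# before_model fabricates a retrieve(query=<first user message>) tool call
# and jumps to the tools node, so retrieval happens before the model runs.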

libs/langchain_v1/langchain/agents/middleware/reflection.py (new file, 14 lines)
@@ -0,0 +1,14 @@
from langchain.agents.types import AgentMiddleware, AgentState, ModelRequest, AgentUpdate, AgentJump


class ReflectionMiddleware(AgentMiddleware):

    def __init__(self, reflection_step):
        self.reflection_step = reflection_step

    def after_model(self, state: AgentState) -> AgentUpdate | AgentJump | None:
        reflection = self.reflection_step(state)
        if reflection:
            return {
                "messages": [{'role': 'user', 'content': reflection}],
                "jump_to": "model"
            }
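
A sketch of a reflection_step (illustrative, not from the commit): any callable from state to an optional critique string fits the contract.

def needs_citation(state):
    # Hypothetical check: request a revision when the answer lacks a URL.
    last = state["messages"][-1]
    if getattr(last, "tool_calls", None):
        return None  # let tool calls run first
    if "http" not in (last.content or ""):
        return "Please cite a source URL for your answer."
    return None

middleware = ReflectionMiddleware(needs_citation)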

libs/langchain_v1/langchain/agents/middleware/supervisor.py (new file, 129 lines)
@@ -0,0 +1,129 @@
import uuid

from langchain.agents.types import AgentMiddleware, AgentState, ModelRequest, AgentUpdate, AgentJump
from typing_extensions import TypedDict, Type
from langchain.tools import tool
from langchain.chat_models import init_chat_model

from langchain.agents.middleware._utils import _generate_correction_tool_messages


class Agent(TypedDict):
    name: str
    description: str
    prompt: str
    tools: list
    model: str
    model_settings: dict
    response_format: Type


class SwarmAgentState(AgentState):
    active_agent: str | None


class SwarmMiddleware(AgentMiddleware):

    state_schema = SwarmAgentState

    def __init__(self, agents: list[Agent], starting_agent: str):
        self.agents = agents
        self.starting_agent = starting_agent
        self.agent_mapping = {a['name']: a for a in agents}

    @property
    def tools(self):
        return [t for a in self.agents for t in a['tools']]

    def _get_handoff_tool(self, agent: Agent):
        @tool(
            name_or_callable=f"handoff_to_{agent['name']}",
            description=f"Handoff to agent {agent['name']}. Description of this agent:\n\n{agent['description']}"
        )
        def handoff():
            pass

        return handoff

    def _get_main_handoff_tools(self):
        tools = []
        for agent in self.agents:
            tools.append(self._get_handoff_tool(agent))
        return tools

    def modify_model_request(self, request: ModelRequest, state: SwarmAgentState) -> ModelRequest:
        if state.get('active_agent') is None:
            # Supervisor turn: expose handoff tools on top of the default request.
            request.tools = request.tools + self._get_main_handoff_tools()
            return request
        active_agent = self.agent_mapping[state['active_agent']]
        request.system_prompt = active_agent['prompt']
        request.tools = active_agent['tools']
        if 'model' in active_agent:
            request.model = init_chat_model(active_agent['model'])
        if 'model_settings' in active_agent:
            request.model_settings = active_agent['model_settings']
        if 'response_format' in active_agent:
            request.response_format = active_agent['response_format']
        return request

    def after_model(self, state: SwarmAgentState) -> AgentUpdate | AgentJump | None:
        messages = state["messages"]
        active_agent = state.get('active_agent')
        if not messages:
            return None

        last_message = messages[-1]

        if not hasattr(last_message, "tool_calls") or not last_message.tool_calls:
            if active_agent is None:
                return None
            else:
                # A subagent finished without calling tools: synthesize a
                # hand-back tool call so control returns to the supervisor.
                fake_tool_call_id = str(uuid.uuid4())
                last_message.tool_calls = [{
                    "type": "tool_call",
                    "name": "hand_back_to_supervisor",
                    "args": {},
                    "id": fake_tool_call_id,
                }]
                fake_tool_message = {"role": "tool", "content": "Handed back to supervisor", "tool_call_id": fake_tool_call_id}
                return {
                    "messages": [last_message, fake_tool_message],
                    "jump_to": "model"
                }
        if active_agent is not None:
            return None
        handoff_tools = self._get_main_handoff_tools()
        handoff_tool_names = [t.name for t in handoff_tools]

        handoffs = []
        for tool_call in last_message.tool_calls:
            if tool_call['name'] in handoff_tool_names:
                handoffs.append(tool_call)
        if len(handoffs) == 0:
            return None
        elif len(handoffs) > 1:
            msg = "Multiple handoffs at the same time are not supported, please just call one at a time."
            return {
                "messages": _generate_correction_tool_messages(msg, last_message.tool_calls),
                "jump_to": "model"
            }
        else:
            tool_call = handoffs[0]
            handoff_to = tool_call['name'][len("handoff_to_"):]
            return {
                "messages": [{
                    "role": "tool",
                    "tool_call_id": tool_call["id"],
                    "content": f"Handed off to agent {handoff_to}",
                }],
                "active_agent": handoff_to,
                "jump_to": "model"
            }
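
A sketch under assumptions (agent specs invented) of wiring this supervisor variant; subagents hand control back implicitly by replying without tool calls:

research = {"name": "research", "description": "Looks things up", "prompt": "You research.", "tools": []}
writer = {"name": "writer", "description": "Drafts prose", "prompt": "You write.", "tools": []}
middleware = SwarmMiddleware(agents=[research, writer], starting_agent="research")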

libs/langchain_v1/langchain/agents/middleware/swarm.py (new file, 111 lines)
@@ -0,0 +1,111 @@
from langchain.agents.types import AgentMiddleware, AgentState, ModelRequest, AgentUpdate, AgentJump
from typing_extensions import TypedDict, Type
from langchain.tools import tool
from langchain.chat_models import init_chat_model
from langchain.agents.middleware._utils import _generate_correction_tool_messages


class Agent(TypedDict):
    name: str
    description: str
    prompt: str
    tools: list
    model: str
    model_settings: dict
    response_format: Type


class SwarmAgentState(AgentState):
    active_agent: str


class SwarmMiddleware(AgentMiddleware):

    state_schema = SwarmAgentState

    def __init__(self, agents: list[Agent], starting_agent: str):
        self.agents = agents
        self.starting_agent = starting_agent
        self.agent_mapping = {a['name']: a for a in agents}

    @property
    def tools(self):
        return [t for a in self.agents for t in a['tools']]

    def _get_handoff_tool(self, agent: Agent):
        @tool(
            name_or_callable=f"handoff_to_{agent['name']}",
            description=f"Handoff to agent {agent['name']}. Description of this agent:\n\n{agent['description']}"
        )
        def handoff():
            pass

        return handoff

    def _get_handoff_tools(self, active_agent: str):
        # Every agent except the active one is a valid handoff target.
        tools = []
        for agent in self.agents:
            if agent['name'] != active_agent:
                tools.append(self._get_handoff_tool(agent))
        return tools

    def modify_model_request(self, request: ModelRequest, state: SwarmAgentState) -> ModelRequest:
        active_agent_name = state.get('active_agent') or self.starting_agent
        active_agent = self.agent_mapping[active_agent_name]
        request.system_prompt = active_agent['prompt']
        # Pass the agent's name, not the Agent dict, when computing handoff targets.
        request.tools = active_agent['tools'] + self._get_handoff_tools(active_agent_name)
        if 'model' in active_agent:
            request.model = init_chat_model(active_agent['model'])
        if 'model_settings' in active_agent:
            request.model_settings = active_agent['model_settings']
        if 'response_format' in active_agent:
            request.response_format = active_agent['response_format']
        return request

    def after_model(self, state: SwarmAgentState) -> AgentUpdate | AgentJump | None:
        active_agent = state.get('active_agent') or self.starting_agent
        messages = state["messages"]
        if not messages:
            return None

        last_message = messages[-1]

        if not hasattr(last_message, "tool_calls") or not last_message.tool_calls:
            return None

        handoff_tools = self._get_handoff_tools(active_agent)
        handoff_tool_names = [t.name for t in handoff_tools]

        handoffs = []
        for tool_call in last_message.tool_calls:
            if tool_call['name'] in handoff_tool_names:
                handoffs.append(tool_call)
        if len(handoffs) == 0:
            return None
        elif len(handoffs) > 1:
            msg = "Multiple handoffs at the same time are not supported, please just call one at a time."
            return {
                "messages": _generate_correction_tool_messages(msg, last_message.tool_calls),
                "jump_to": "model"
            }
        else:
            tool_call = handoffs[0]
            handoff_to = tool_call['name'][len("handoff_to_"):]
            return {
                "messages": [{
                    "role": "tool",
                    "tool_call_id": tool_call["id"],
                    "content": f"Handed off to agent {handoff_to}",
                }],
                "active_agent": handoff_to,
                "jump_to": "model"
            }
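
Note the membership checks in modify_model_request: although Agent declares model, model_settings, and response_format, the code treats them as optional keys. An illustrative spec (made up):

coder = {
    "name": "coder",
    "description": "Writes code",
    "prompt": "You write code.",
    "tools": [],
    # "model": "openai:gpt-4o",  # optional; omitted keys fall back to the request defaults
}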

@@ -130,7 +130,7 @@ def create_agent(
         native_output_binding = ProviderStrategyBinding.from_schema_spec(
             response_format.schema_spec
         )
 
+    middleware_tools = [t for m in middleware for t in m.tools]
     # Setup tools
     if isinstance(tools, list):
         # Extract builtin provider tools (dict format)
@@ -139,25 +139,28 @@ def create_agent(
 
         # Add structured output tools to regular tools
         structured_tools = [info.tool for info in structured_output_tools.values()]
-        all_tools = [t for m in middleware for t in m.tools] + regular_tools + structured_tools
+        all_tools = middleware_tools + regular_tools + structured_tools
 
         if all_tools:  # Only create ToolNode if we have tools
             tool_node = ToolNode(tools=all_tools)
         else:
             tool_node = None
-        default_tools = regular_tools + builtin_tools + structured_tools
+        default_tools = regular_tools + builtin_tools + structured_tools + middleware_tools
     else:
         # tools is ToolNode or None
         tool_node = tools
         if tool_node:
-            default_tools = list(tool_node.tools_by_name.values())
+            default_tools = list(tool_node.tools_by_name.values()) + middleware_tools
+            # Update tool node to know about tools provided by middleware
+            all_tools = list(tool_node.tools_by_name.values()) + middleware_tools
+            tool_node = ToolNode(all_tools)
             # Add structured output tools
             for info in structured_output_tools.values():
                 default_tools.append(info.tool)
         else:
             default_tools = (
                 list(structured_output_tools.values()) if structured_output_tools else []
-            )
+            ) + middleware_tools
 
     # validate middleware
     assert len({m.__class__.__name__ for m in middleware}) == len(middleware), (
@@ -400,8 +403,21 @@ def create_agent(
             [first_node, "tools", END],
         )
     else:
-        # If no tools, just go to END from model
-        graph.add_edge(last_node, END)
+        if last_node == "model_request":
+            # If no tools, just go to END from model
+            graph.add_edge(last_node, END)
+        else:
+            # If after_model, then need to check for jump_to
+            _add_middleware_edge(
+                graph,
+                middleware_w_after[0].after_model,
+                f"{middleware_w_after[0].__class__.__name__}.after_model",
+                END,
+                first_node,
+                tools_available=tool_node is not None,
+            )
 
     # Add middleware edges (same as before)
     if middleware_w_before:
@@ -414,11 +430,16 @@ def create_agent(
             first_node,
             tools_available=tool_node is not None,
         )
+        if middleware_w_modify_model_request:
+            first_modify = middleware_w_modify_model_request[0]
+            next_node = f"{first_modify.__class__.__name__}.modify_model_request"
+        else:
+            next_node = "model_request"
         _add_middleware_edge(
             graph,
             middleware_w_before[-1].before_model,
             f"{middleware_w_before[-1].__class__.__name__}.before_model",
-            "model_request",
+            next_node,
             first_node,
             tools_available=tool_node is not None,
         )
@@ -530,7 +551,9 @@ def _add_middleware_edge(
     def jump_edge(state: AgentState) -> str:
         return _resolve_jump(state.get("jump_to"), model_destination) or default_destination
 
-    destinations = [default_destination, END]
+    destinations = [default_destination]
+    if END != default_destination:
+        destinations.append(END)
     if tools_available:
         destinations.append("tools")
     if name != model_destination:
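
The net effect of these hunks: tools exposed by middleware through the tools property are merged into the ToolNode and the model's default tool list. A sketch with a made-up middleware:

from langchain_core.tools import tool
from langchain.agents.types import AgentMiddleware

@tool
def clock() -> str:
    """Return a fixed timestamp (illustrative)."""
    return "2025-01-01T00:00:00Z"

class ClockMiddleware(AgentMiddleware):
    @property
    def tools(self):
        return [clock]

# create_agent(model, [], middleware=[ClockMiddleware()]) now routes clock
# calls through the shared tools node without listing the tool explicitly.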

libs/langchain_v1/langchain/chat_models/fake.py (new file, 112 lines)
@@ -0,0 +1,112 @@
import json
from collections.abc import Callable, Sequence
from dataclasses import asdict, is_dataclass
from typing import (
    Any,
    Generic,
    Literal,
    TypeVar,
    Union,
)

from langchain_core.callbacks import CallbackManagerForLLMRun
from langchain_core.language_models import BaseChatModel, LanguageModelInput
from langchain_core.messages import (
    AIMessage,
    BaseMessage,
    ToolCall,
)
from langchain_core.outputs import ChatGeneration, ChatResult
from langchain_core.runnables import Runnable
from langchain_core.tools import BaseTool
from pydantic import BaseModel

StructuredResponseT = TypeVar("StructuredResponseT")


class FakeToolCallingModel(BaseChatModel, Generic[StructuredResponseT]):
    tool_calls: Union[list[list[ToolCall]], list[list[dict]]] | None = None
    structured_response: StructuredResponseT | None = None
    index: int = 0
    tool_style: Literal["openai", "anthropic"] = "openai"
    tools: list = []

    def _generate(
        self,
        messages: list[BaseMessage],
        stop: list[str] | None = None,
        run_manager: CallbackManagerForLLMRun | None = None,
        **kwargs: Any,
    ) -> ChatResult:
        """Top level call"""
        rf = kwargs.get("response_format")
        is_native = isinstance(rf, dict) and rf.get("type") == "json_schema"

        if self.tool_calls:
            if is_native:
                tool_calls = (
                    self.tool_calls[self.index] if self.index < len(self.tool_calls) else []
                )
            else:
                tool_calls = self.tool_calls[self.index % len(self.tool_calls)]
        else:
            tool_calls = []

        if is_native and not tool_calls:
            if isinstance(self.structured_response, BaseModel):
                content_obj = self.structured_response.model_dump()
            elif is_dataclass(self.structured_response):
                content_obj = asdict(self.structured_response)
            elif isinstance(self.structured_response, dict):
                content_obj = self.structured_response
            message = AIMessage(content=json.dumps(content_obj), id=str(self.index))
        else:
            messages_string = "-".join([m.content for m in messages]) + str(kwargs) + str(self.tools)
            message = AIMessage(
                content=messages_string,
                id=str(self.index),
                tool_calls=tool_calls.copy(),
            )
        self.index += 1
        return ChatResult(generations=[ChatGeneration(message=message)])

    @property
    def _llm_type(self) -> str:
        return "fake-tool-call-model"

    def bind_tools(
        self,
        tools: Sequence[Union[dict[str, Any], type[BaseModel], Callable, BaseTool]],
        **kwargs: Any,
    ) -> Runnable[LanguageModelInput, BaseMessage]:
        self.tools = tools
        if len(tools) == 0:
            msg = "Must provide at least one tool"
            raise ValueError(msg)

        tool_dicts = []
        for tool in tools:
            if isinstance(tool, dict):
                tool_dicts.append(tool)
                continue
            if not isinstance(tool, BaseTool):
                continue

            # NOTE: this is a simplified tool spec for testing purposes only
            if self.tool_style == "openai":
                tool_dicts.append(
                    {
                        "type": "function",
                        "function": {
                            "name": tool.name,
                        },
                    }
                )
            elif self.tool_style == "anthropic":
                tool_dicts.append(
                    {
                        "name": tool.name,
                    }
                )

        return self.bind(tools=tool_dicts)
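
A quick sketch (illustrative) of scripting the fake model: each inner list is the tool_calls payload for one successive response, and an empty list lets an agent loop terminate.

model = FakeToolCallingModel(
    tool_calls=[[{"name": "ls", "args": {}, "id": "1"}], []]
)
# The first generated response calls ls; the second makes no tool calls.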

libs/langchain_v1/scripts/deepagent.py (new file, 9 lines)
@@ -0,0 +1,9 @@
from langchain.agents import create_agent
from langchain.agents.middleware.deepagents import DeepAgentMiddleware
from langchain.chat_models.fake import FakeToolCallingModel

model = FakeToolCallingModel()
agent = create_agent(model, [], middleware=[DeepAgentMiddleware()])

for s in agent.stream({"messages": [{"role": "user", "content": "hi"}]}, stream_mode="debug"):
    print(s)

libs/langchain_v1/scripts/rag.py (new file, 24 lines)
@@ -0,0 +1,24 @@
from langchain_core.callbacks import CallbackManagerForRetrieverRun
from langchain_core.documents import Document

from langchain.agents import create_agent
from langchain.agents.middleware.rag import RAGMiddleware
from langchain.chat_models.fake import FakeToolCallingModel
from langchain_core.retrievers import BaseRetriever


class FakeRetriever(BaseRetriever):

    def _get_relevant_documents(
        self, query: str, *, run_manager: CallbackManagerForRetrieverRun
    ) -> list[Document]:
        return [Document(page_content="foo")]


model = FakeToolCallingModel()
middleware = RAGMiddleware.from_retriever(FakeRetriever(), "foo")
agent = create_agent(model, [], middleware=[middleware])
print(agent.get_graph())
for s in agent.stream({"messages": [{"role": "user", "content": "hi"}]}, stream_mode="debug"):
    print(s)

libs/langchain_v1/scripts/swarm.py (new file, 15 lines)
@@ -0,0 +1,15 @@
from langchain.agents import create_agent
from langchain.agents.middleware.swarm import SwarmMiddleware
from langchain.chat_models.fake import FakeToolCallingModel

tool_calls = [[{"args": {}, "id": "1", "name": "handoff_to_foo2"}], []]
model = FakeToolCallingModel(tool_calls=tool_calls)
subagents = [
    {"name": "foo1", "description": "bar1", "prompt": "hi", "tools": []},
    {"name": "foo2", "description": "bar1", "prompt": "bye", "tools": []}
]
middleware = SwarmMiddleware(agents=subagents, starting_agent="foo1")
agent = create_agent(model, [], middleware=[middleware])
print(agent.get_graph())
for s in agent.stream({"messages": [{"role": "user", "content": "hi"}]}, stream_mode="debug"):
    print(s)