From dccfcd3f08fd6d96fceb6131515f080505f40ab2 Mon Sep 17 00:00:00 2001 From: Harrison Chase Date: Sun, 7 Sep 2025 12:26:43 +0200 Subject: [PATCH] new middleware --- .../langchain/agents/middleware/_utils.py | 9 ++ .../langchain/agents/middleware/deepagents.py | 68 ++++++++ .../agents/middleware/dynamic_prompt.py | 22 +++ .../agents/middleware/human_in_the_loop.py | 20 ++- .../agents/middleware/hybrid_supervisor.py | 145 ++++++++++++++++++ .../langchain/agents/middleware/rag.py | 35 +++++ .../langchain/agents/middleware/reflection.py | 14 ++ .../langchain/agents/middleware/supervisor.py | 129 ++++++++++++++++ .../langchain/agents/middleware/swarm.py | 111 ++++++++++++++ .../langchain/agents/middleware_agent.py | 41 +++-- .../langchain/chat_models/fake.py | 112 ++++++++++++++ libs/langchain_v1/scripts/deepagent.py | 9 ++ libs/langchain_v1/scripts/rag.py | 24 +++ libs/langchain_v1/scripts/swarm.py | 15 ++ 14 files changed, 741 insertions(+), 13 deletions(-) create mode 100644 libs/langchain_v1/langchain/agents/middleware/_utils.py create mode 100644 libs/langchain_v1/langchain/agents/middleware/deepagents.py create mode 100644 libs/langchain_v1/langchain/agents/middleware/dynamic_prompt.py create mode 100644 libs/langchain_v1/langchain/agents/middleware/hybrid_supervisor.py create mode 100644 libs/langchain_v1/langchain/agents/middleware/rag.py create mode 100644 libs/langchain_v1/langchain/agents/middleware/reflection.py create mode 100644 libs/langchain_v1/langchain/agents/middleware/supervisor.py create mode 100644 libs/langchain_v1/langchain/agents/middleware/swarm.py create mode 100644 libs/langchain_v1/langchain/chat_models/fake.py create mode 100644 libs/langchain_v1/scripts/deepagent.py create mode 100644 libs/langchain_v1/scripts/rag.py create mode 100644 libs/langchain_v1/scripts/swarm.py diff --git a/libs/langchain_v1/langchain/agents/middleware/_utils.py b/libs/langchain_v1/langchain/agents/middleware/_utils.py new file mode 100644 index 00000000000..24789e929f0 --- /dev/null +++ b/libs/langchain_v1/langchain/agents/middleware/_utils.py @@ -0,0 +1,9 @@ +def _generate_correction_tool_messages(content: str, tool_calls: list): + tool_messages = [] + for tool_call in tool_calls: + tool_messages.append({ + "role": "tool", + "content": content, + "tool_call_id": tool_call["id"] + }) + return tool_messages diff --git a/libs/langchain_v1/langchain/agents/middleware/deepagents.py b/libs/langchain_v1/langchain/agents/middleware/deepagents.py new file mode 100644 index 00000000000..39e119e8267 --- /dev/null +++ b/libs/langchain_v1/langchain/agents/middleware/deepagents.py @@ -0,0 +1,68 @@ +from langchain.agents.types import AgentMiddleware, AgentState, ModelRequest +from typing import NotRequired, Annotated +from typing import Literal +from typing_extensions import TypedDict + + +class Todo(TypedDict): + """Todo to track.""" + + content: str + status: Literal["pending", "in_progress", "completed"] + + +def file_reducer(l, r): + if l is None: + return r + elif r is None: + return l + else: + return {**l, **r} + + +class DeepAgentState(AgentState): + todos: NotRequired[list[Todo]] + files: Annotated[NotRequired[dict[str, str]], file_reducer] + + +from langchain_core.tools import tool, InjectedToolCallId +from langgraph.types import Command +from langchain_core.messages import ToolMessage +from typing import Annotated, Union +from langgraph.prebuilt import InjectedState + +def write_todos( + todos: list[Todo], tool_call_id: Annotated[str, InjectedToolCallId] +) -> Command: + """write todos""" + return Command( + update={ + "todos": todos, + "messages": [ + ToolMessage(f"Updated todo list to {todos}", tool_call_id=tool_call_id) + ], + } + ) + + +def ls(state: Annotated[DeepAgentState, InjectedState]) -> list[str]: + """List all files""" + return list(state.get("files", {}).keys()) + +class DeepAgentMiddleware(AgentMiddleware): + + state_schema = DeepAgentState + + def __init__(self, subagents: list = []): + self.subagents = subagents + + @property + def tools(self): + return [write_todos, ls] + self.subagents + + def modify_model_request(self, request: ModelRequest, state: DeepAgentState) -> ModelRequest: + if request.system_prompt: + request.system_prompt += "\n\nUse the todo tool to plan as needed" + else: + request.system_prompt = "Use the todo tool to plan as needed" + return request diff --git a/libs/langchain_v1/langchain/agents/middleware/dynamic_prompt.py b/libs/langchain_v1/langchain/agents/middleware/dynamic_prompt.py new file mode 100644 index 00000000000..3ad7ba9fe7b --- /dev/null +++ b/libs/langchain_v1/langchain/agents/middleware/dynamic_prompt.py @@ -0,0 +1,22 @@ +from langchain.agents.types import AgentMiddleware, AgentState, ModelRequest + +class DynamicPrompt(AgentMiddleware): + + def __init__(self, modifier): + self.modifier = modifier + + def modify_model_request(self, request: ModelRequest, state) -> ModelRequest: + prompt = self.modifier(state) + request.system_prompt = prompt + return request + + +class DynamicMessages(AgentMiddleware): + + def __init__(self, modifier): + self.modifier = modifier + + def modify_model_request(self, request: ModelRequest, state) -> ModelRequest: + messages = self.modifier(state) + request.messages = messages + return request diff --git a/libs/langchain_v1/langchain/agents/middleware/human_in_the_loop.py b/libs/langchain_v1/langchain/agents/middleware/human_in_the_loop.py index d19932f9203..2aaec16a2f0 100644 --- a/libs/langchain_v1/langchain/agents/middleware/human_in_the_loop.py +++ b/libs/langchain_v1/langchain/agents/middleware/human_in_the_loop.py @@ -7,6 +7,7 @@ from langgraph.prebuilt.interrupt import ( from langgraph.types import interrupt from langchain.agents.types import AgentJump, AgentMiddleware, AgentState, AgentUpdate +from langchain.agents.middleware._utils import _generate_correction_tool_messages ToolInterruptConfig = dict[str, HumanInterruptConfig] @@ -50,11 +51,22 @@ class HumanInTheLoopMiddleware(AgentMiddleware): # Right now, we do not support multiple tool calls with interrupts if len(interrupt_tool_calls) > 1: - raise ValueError("Does not currently support multiple tool calls with interrupts") + tool_names = [t['name'] for t in interrupt_tool_calls] + msg = f"Called the following tools which require interrupts: {tool_names}\n\nYou may only call ONE tool that requires an interrupt at a time" + return { + "messages": _generate_correction_tool_messages(msg, last_message.tool_calls), + "jump_to": "model" + } # Right now, we do not support interrupting a tool call if other tool calls exist if auto_approved_tool_calls: - raise ValueError("Does not currently support interrupting a tool call if other tool calls exist") + tool_names = [t['name'] for t in interrupt_tool_calls] + msg = f"Called the following tools which require interrupts: {tool_names}. You also called other tools that do not require interrupts. If you call a tool that requires and interrupt, you may ONLY call that tool." + return { + "messages": _generate_correction_tool_messages(msg, + last_message.tool_calls), + "jump_to": "model" + } # Only one tool call will need interrupts tool_call = interrupt_tool_calls[0] @@ -87,14 +99,14 @@ class HumanInTheLoopMiddleware(AgentMiddleware): } approved_tool_calls.append(new_tool_call) elif response["type"] == "ignore": - return {"goto": "__end__"} + return {"jump_to": "__end__"} elif response["type"] == "response": tool_message = { "role": "tool", "tool_call_id": tool_call["id"], "content": response["args"], } - return {"messages": [tool_message], "goto": "model"} + return {"messages": [tool_message], "jump_to": "model"} else: raise ValueError(f"Unknown response type: {response['type']}") diff --git a/libs/langchain_v1/langchain/agents/middleware/hybrid_supervisor.py b/libs/langchain_v1/langchain/agents/middleware/hybrid_supervisor.py new file mode 100644 index 00000000000..beaab8a0066 --- /dev/null +++ b/libs/langchain_v1/langchain/agents/middleware/hybrid_supervisor.py @@ -0,0 +1,145 @@ +from langchain.agents.types import AgentMiddleware, AgentState, ModelRequest, AgentUpdate, AgentJump +from typing_extensions import TypedDict, Type +from langchain.tools import tool +from langchain.chat_models import init_chat_model +from langchain.agents.middleware._utils import _generate_correction_tool_messages + +_HANDBACK_NAME = "hand_back" + + +class Agent(TypedDict): + name: str + description: str + prompt: str + tools: list + model: str + model_settings: dict + response_format: Type + + +class SwarmAgentState(AgentState): + active_agent: str | None + + +class SwarmMiddleware(AgentMiddleware): + + state_schema = SwarmAgentState + + def __init__(self, agents: list[Agent], starting_agent: str): + self.agents = agents + self.starting_agent = starting_agent + self.agent_mapping = {a['name']: a for a in agents} + + @property + def tools(self): + return [t for a in self.agents for t in a['tools']] + + def _get_handoff_tool(self, agent: Agent): + @tool( + name_or_callable=f"handoff_to_{agent['name']}", + description=f"Handoff to agent {agent['name']}. Description of this agent:\n\n{agent['description']}" + ) + def handoff(): + pass + + return handoff + + def _get_pass_back_tool(self): + @tool(name_or_callable=_HANDBACK_NAME, + description="Call this if you are unable to handle the current request. You will hand back control of the conversation to your supervisor") + def hand_back(): + pass + + return hand_back + + + def _get_main_handoff_tools(self): + tools = [] + for agent in self.agents: + tools.append(self._get_handoff_tool(agent)) + return tools + + + def modify_model_request(self, request: ModelRequest, state: SwarmAgentState) -> ModelRequest: + if state.get('active_agent') is None: + request.tools = request.tools + self._get_main_handoff_tools() + return request + active_agent = self.agent_mapping[state['active_agent']] + request.system_prompt = active_agent['prompt'] + request.tools = active_agent['tools'] + self._get_handoff_tool() + if 'model' in active_agent: + request.model = init_chat_model(active_agent['model']) + if 'model_settings' in active_agent: + request.model_settings = active_agent['model_settings'] + if 'response_format' in active_agent: + request.response_format = active_agent['response_format'] + return request + + def after_model(self, state: SwarmAgentState) -> AgentUpdate | AgentJump | None: + messages = state["messages"] + active_agent = state.get('active_agent') + if not messages: + return None + + last_message = messages[-1] + + if not hasattr(last_message, "tool_calls") or not last_message.tool_calls: + return None + if active_agent is not None: + handoffs = [] + for tool_call in last_message.tool_calls: + if tool_call['name'] == _HANDBACK_NAME: + handoffs.append(tool_call) + if len(handoffs) == 0: + return None + elif len(handoffs) > 1: + msg = "Multiple handoffs at the same time are not supported, please just call one at a time." + return { + "messages": _generate_correction_tool_messages(msg, + last_message.tool_calls), + "jump_to": "model" + } + else: + tool_call = handoffs[0] + return { + "messages": [{ + "role": "tool", + "tool_call_id": tool_call["id"], + "content": f"Handed back to supervisor", + }], + "active_agent": None, + "jump_to": "model" + } + handoff_tools = self._get_main_handoff_tools() + handoff_tool_names = [t.name for t in handoff_tools] + + handoffs = [] + for tool_call in last_message.tool_calls: + if tool_call['name'] in handoff_tool_names: + handoffs.append(tool_call) + if len(handoffs) == 0: + return + elif len(handoffs) > 1: + msg = "Multiple handoffs at the same time are not supported, please just call one at a time." + return { + "messages": _generate_correction_tool_messages(msg, + last_message.tool_calls), + "jump_to": "model" + } + else: + tool_call = handoffs[0] + handoff_to = tool_call['name'][len("handoff_to_"):] + return { + "messages":[{ + "role": "tool", + "tool_call_id": tool_call["id"], + "content": f"Handed off to agent {handoff_to}", + }], + "active_agent":handoff_to, + "jump_to": "model" + } + + + + + diff --git a/libs/langchain_v1/langchain/agents/middleware/rag.py b/libs/langchain_v1/langchain/agents/middleware/rag.py new file mode 100644 index 00000000000..38a668fbaa9 --- /dev/null +++ b/libs/langchain_v1/langchain/agents/middleware/rag.py @@ -0,0 +1,35 @@ +from langchain.agents.types import AgentMiddleware, AgentState, ModelRequest, AgentJump, AgentUpdate +from langchain_core.retrievers import BaseRetriever +from langchain_core.tools import tool +import uuid + +class RAGMiddleware(AgentMiddleware): + + @classmethod + def from_retriever(cls, retriever: BaseRetriever, description: str): + @tool(description=description) + def retrieve(query: str): + return retriever.get_relevant_documents(query) + + return cls(retrieve) + + + def __init__(self, tool): + self.tool = tool + + @property + def tools(self): + return [self.tool] + + def before_model(self, state: AgentState) -> AgentUpdate | AgentJump | None: + if len(state['messages']) == 1: + forced_tool_call = { + "type": "tool_call", + "name": self.tool.name, + "args": {"query": state['messages'][0].content}, + "id": str(uuid.uuid4()), + } + return { + "messages": [{"role": "ai", "content": None, "tool_calls": [forced_tool_call]}], + "jump_to": "tools" + } diff --git a/libs/langchain_v1/langchain/agents/middleware/reflection.py b/libs/langchain_v1/langchain/agents/middleware/reflection.py new file mode 100644 index 00000000000..533b44017ce --- /dev/null +++ b/libs/langchain_v1/langchain/agents/middleware/reflection.py @@ -0,0 +1,14 @@ +from langchain.agents.types import AgentMiddleware, AgentState, ModelRequest, AgentUpdate, AgentJump + +class ReflectionMiddleware(AgentMiddleware): + + def __init__(self, reflection_step): + self.reflection_step = reflection_step + + def after_model(self, state: AgentState) -> AgentUpdate | AgentJump | None: + reflection = self.reflection_step(state) + if reflection: + return { + "messages": [{'role': 'user', 'content': reflection}], + "jump_to": "model" + } diff --git a/libs/langchain_v1/langchain/agents/middleware/supervisor.py b/libs/langchain_v1/langchain/agents/middleware/supervisor.py new file mode 100644 index 00000000000..4b7151951fa --- /dev/null +++ b/libs/langchain_v1/langchain/agents/middleware/supervisor.py @@ -0,0 +1,129 @@ +import uuid + +from langchain.agents.types import AgentMiddleware, AgentState, ModelRequest, AgentUpdate, AgentJump +from typing_extensions import TypedDict, Type +from langchain.tools import tool +from langchain.chat_models import init_chat_model + +from langchain.agents.middleware._utils import _generate_correction_tool_messages + + +class Agent(TypedDict): + name: str + description: str + prompt: str + tools: list + model: str + model_settings: dict + response_format: Type + + +class SwarmAgentState(AgentState): + active_agent: str | None + + +class SwarmMiddleware(AgentMiddleware): + + state_schema = SwarmAgentState + + def __init__(self, agents: list[Agent], starting_agent: str): + self.agents = agents + self.starting_agent = starting_agent + self.agent_mapping = {a['name']: a for a in agents} + + @property + def tools(self): + return [t for a in self.agents for t in a['tools']] + + def _get_handoff_tool(self, agent: Agent): + @tool( + name_or_callable=f"handoff_to_{agent['name']}", + description=f"Handoff to agent {agent['name']}. Description of this agent:\n\n{agent['description']}" + ) + def handoff(): + pass + + return handoff + + + def _get_main_handoff_tools(self): + tools = [] + for agent in self.agents: + tools.append(self._get_handoff_tool(agent)) + return tools + + + def modify_model_request(self, request: ModelRequest, state: SwarmAgentState) -> ModelRequest: + if state.get('active_agent') is None: + request.tools = request.tools + self._get_main_handoff_tools() + return request + active_agent = self.agent_mapping[state['active_agent']] + request.system_prompt = active_agent['prompt'] + request.tools = active_agent['tools'] + if 'model' in active_agent: + request.model = init_chat_model(active_agent['model']) + if 'model_settings' in active_agent: + request.model_settings = active_agent['model_settings'] + if 'response_format' in active_agent: + request.response_format = active_agent['response_format'] + return request + + def after_model(self, state: SwarmAgentState) -> AgentUpdate | AgentJump | None: + messages = state["messages"] + active_agent = state.get('active_agent') + if not messages: + return None + + last_message = messages[-1] + + if not hasattr(last_message, "tool_calls") or not last_message.tool_calls: + if active_agent is None: + return None + else: + fake_tool_call_id = str(uuid.uuid4()) + last_message.tool_calls = [{ + "type": "tool_call", + "name": "hand_back_to_supervisor", + "args": {}, + "id": fake_tool_call_id, + }] + fake_tool_message = {"role": "tool", "content": "Handed back to supervisor", "tool_call_id": fake_tool_call_id} + return { + "messages": [last_message, fake_tool_message], + "jump_to": "model" + } + if active_agent is not None: + return None + handoff_tools = self._get_main_handoff_tools() + handoff_tool_names = [t.name for t in handoff_tools] + + handoffs = [] + for tool_call in last_message.tool_calls: + if tool_call['name'] in handoff_tool_names: + handoffs.append(tool_call) + if len(handoffs) == 0: + return + elif len(handoffs) > 1: + msg = "Multiple handoffs at the same time are not supported, please just call one at a time." + return { + "messages": _generate_correction_tool_messages(msg, + last_message.tool_calls), + "jump_to": "model" + } + else: + tool_call = handoffs[0] + handoff_to = tool_call['name'][len("handoff_to_"):] + return { + "messages":[{ + "role": "tool", + "tool_call_id": tool_call["id"], + "content": f"Handed off to agent {handoff_to}", + }], + "active_agent":handoff_to, + "jump_to": "model" + } + + + + + diff --git a/libs/langchain_v1/langchain/agents/middleware/swarm.py b/libs/langchain_v1/langchain/agents/middleware/swarm.py new file mode 100644 index 00000000000..c80f3f20ed6 --- /dev/null +++ b/libs/langchain_v1/langchain/agents/middleware/swarm.py @@ -0,0 +1,111 @@ +from langchain.agents.types import AgentMiddleware, AgentState, ModelRequest, AgentUpdate, AgentJump +from typing_extensions import TypedDict, Type +from langchain.tools import tool +from langchain.chat_models import init_chat_model +from langchain.agents.middleware._utils import _generate_correction_tool_messages + + +class Agent(TypedDict): + name: str + description: str + prompt: str + tools: list + model: str + model_settings: dict + response_format: Type + + +class SwarmAgentState(AgentState): + active_agent: str + + +class SwarmMiddleware(AgentMiddleware): + + state_schema = SwarmAgentState + + def __init__(self, agents: list[Agent], starting_agent: str): + self.agents = agents + self.starting_agent = starting_agent + self.agent_mapping = {a['name']: a for a in agents} + + @property + def tools(self): + return [t for a in self.agents for t in a['tools']] + + + def _get_handoff_tool(self, agent: Agent): + @tool( + name_or_callable=f"handoff_to_{agent['name']}", + description=f"Handoff to agent {agent['name']}. Description of this agent:\n\n{agent['description']}" + ) + def handoff(): + pass + + return handoff + + + def _get_handoff_tools(self, active_agent: str): + tools = [] + for agent in self.agents: + if agent['name'] != active_agent: + tools.append(self._get_handoff_tool(agent)) + return tools + + + def modify_model_request(self, request: ModelRequest, state: SwarmAgentState) -> ModelRequest: + active_agent_name = state.get('active_agent') or self.starting_agent + active_agent = self.agent_mapping[active_agent_name] + request.system_prompt = active_agent['prompt'] + request.tools = active_agent['tools'] + self._get_handoff_tools(active_agent) + if 'model' in active_agent: + request.model = init_chat_model(active_agent['model']) + if 'model_settings' in active_agent: + request.model_settings = active_agent['model_settings'] + if 'response_format' in active_agent: + request.response_format = active_agent['response_format'] + return request + + def after_model(self, state: SwarmAgentState) -> AgentUpdate | AgentJump | None: + active_agent = state.get('active_agent') or self.starting_agent + messages = state["messages"] + if not messages: + return None + + last_message = messages[-1] + + if not hasattr(last_message, "tool_calls") or not last_message.tool_calls: + return None + + handoff_tools = self._get_handoff_tools(active_agent) + handoff_tool_names = [t.name for t in handoff_tools] + + handoffs = [] + for tool_call in last_message.tool_calls: + if tool_call['name'] in handoff_tool_names: + handoffs.append(tool_call) + if len(handoffs) == 0: + return + elif len(handoffs) > 1: + msg = "Multiple handoffs at the same time are not supported, please just call one at a time." + return { + "messages": _generate_correction_tool_messages(msg, + last_message.tool_calls), + "jump_to": "model" + } + else: + tool_call = handoffs[0] + handoff_to = tool_call['name'][len("handoff_to_"):] + return { + "messages":[{ + "role": "tool", + "tool_call_id": tool_call["id"], + "content": f"Handed off to agent {handoff_to}", + }], + "active_agent":handoff_to, + "jump_to": "model" + } + + + + + diff --git a/libs/langchain_v1/langchain/agents/middleware_agent.py b/libs/langchain_v1/langchain/agents/middleware_agent.py index cee9f85f067..2d36dc0d29f 100644 --- a/libs/langchain_v1/langchain/agents/middleware_agent.py +++ b/libs/langchain_v1/langchain/agents/middleware_agent.py @@ -130,7 +130,7 @@ def create_agent( native_output_binding = ProviderStrategyBinding.from_schema_spec( response_format.schema_spec ) - + middleware_tools = [t for m in middleware for t in m.tools] # Setup tools if isinstance(tools, list): # Extract builtin provider tools (dict format) @@ -139,25 +139,28 @@ def create_agent( # Add structured output tools to regular tools structured_tools = [info.tool for info in structured_output_tools.values()] - all_tools = [t for m in middleware for t in m.tools] + regular_tools + structured_tools + all_tools = middleware_tools + regular_tools + structured_tools if all_tools: # Only create ToolNode if we have tools tool_node = ToolNode(tools=all_tools) else: tool_node = None - default_tools = regular_tools + builtin_tools + structured_tools + default_tools = regular_tools + builtin_tools + structured_tools + middleware_tools else: # tools is ToolNode or None tool_node = tools if tool_node: - default_tools = list(tool_node.tools_by_name.values()) + default_tools = list(tool_node.tools_by_name.values()) + middleware_tools + # Update tool node to know about tools provided by middleware + all_tools = list(tool_node.tools_by_name.values()) + middleware_tools + tool_node = ToolNode(all_tools) # Add structured output tools for info in structured_output_tools.values(): default_tools.append(info.tool) else: default_tools = ( list(structured_output_tools.values()) if structured_output_tools else [] - ) + ) + middleware_tools # validate middleware assert len({m.__class__.__name__ for m in middleware}) == len(middleware), ( @@ -400,8 +403,21 @@ def create_agent( [first_node, "tools", END], ) else: - # If no tools, just go to END from model - graph.add_edge(last_node, END) + if last_node == "model_request": + # If no tools, just go to END from model + graph.add_edge(last_node, END) + else: + # If after_model, then need to check for jump_to + _add_middleware_edge( + graph, + middleware_w_after[0].after_model, + f"{middleware_w_after[0].__class__.__name__}.after_model", + END, + first_node, + tools_available=tool_node is not None, + + ) + # Add middleware edges (same as before) if middleware_w_before: @@ -414,11 +430,16 @@ def create_agent( first_node, tools_available=tool_node is not None, ) + if middleware_w_modify_model_request: + first_modify = middleware_w_modify_model_request[0] + next_node = f"{first_modify.__class__.__name__}.modify_model_request" + else: + next_node = "model_request" _add_middleware_edge( graph, middleware_w_before[-1].before_model, f"{middleware_w_before[-1].__class__.__name__}.before_model", - "model_request", + next_node, first_node, tools_available=tool_node is not None, ) @@ -530,7 +551,9 @@ def _add_middleware_edge( def jump_edge(state: AgentState) -> str: return _resolve_jump(state.get("jump_to"), model_destination) or default_destination - destinations = [default_destination, END] + destinations = [default_destination] + if END != default_destination: + destinations.append(END) if tools_available: destinations.append("tools") if name != model_destination: diff --git a/libs/langchain_v1/langchain/chat_models/fake.py b/libs/langchain_v1/langchain/chat_models/fake.py new file mode 100644 index 00000000000..d844f84b7cb --- /dev/null +++ b/libs/langchain_v1/langchain/chat_models/fake.py @@ -0,0 +1,112 @@ +import json +from collections.abc import Callable, Sequence +from dataclasses import asdict, is_dataclass +from typing import ( + Any, + Generic, + Literal, + TypeVar, + Union, +) + +from langchain_core.callbacks import CallbackManagerForLLMRun +from langchain_core.language_models import BaseChatModel, LanguageModelInput +from langchain_core.messages import ( + AIMessage, + BaseMessage, + ToolCall, +) +from langchain_core.outputs import ChatGeneration, ChatResult +from langchain_core.runnables import Runnable +from langchain_core.tools import BaseTool +from pydantic import BaseModel + +StructuredResponseT = TypeVar("StructuredResponseT") + + +class FakeToolCallingModel(BaseChatModel, Generic[StructuredResponseT]): + tool_calls: Union[list[list[ToolCall]], list[list[dict]]] | None = None + structured_response: StructuredResponseT | None = None + index: int = 0 + tool_style: Literal["openai", "anthropic"] = "openai" + tools: list = [] + + def _generate( + self, + messages: list[BaseMessage], + stop: list[str] | None = None, + run_manager: CallbackManagerForLLMRun | None = None, + **kwargs: Any, + ) -> ChatResult: + """Top Level call""" + rf = kwargs.get("response_format") + is_native = isinstance(rf, dict) and rf.get("type") == "json_schema" + + if self.tool_calls: + if is_native: + tool_calls = ( + self.tool_calls[self.index] if self.index < len(self.tool_calls) else [] + ) + else: + tool_calls = self.tool_calls[self.index % len(self.tool_calls)] + else: + tool_calls = [] + + if is_native and not tool_calls: + if isinstance(self.structured_response, BaseModel): + content_obj = self.structured_response.model_dump() + elif is_dataclass(self.structured_response): + content_obj = asdict(self.structured_response) + elif isinstance(self.structured_response, dict): + content_obj = self.structured_response + message = AIMessage(content=json.dumps(content_obj), id=str(self.index)) + else: + messages_string = "-".join([m.content for m in messages]) + str(kwargs) + str(self.tools) + message = AIMessage( + content=messages_string, + id=str(self.index), + tool_calls=tool_calls.copy(), + ) + self.index += 1 + return ChatResult(generations=[ChatGeneration(message=message)]) + + @property + def _llm_type(self) -> str: + return "fake-tool-call-model" + + def bind_tools( + self, + tools: Sequence[Union[dict[str, Any], type[BaseModel], Callable, BaseTool]], + **kwargs: Any, + ) -> Runnable[LanguageModelInput, BaseMessage]: + self.tools = tools + if len(tools) == 0: + msg = "Must provide at least one tool" + raise ValueError(msg) + + tool_dicts = [] + for tool in tools: + if isinstance(tool, dict): + tool_dicts.append(tool) + continue + if not isinstance(tool, BaseTool): + continue + + # NOTE: this is a simplified tool spec for testing purposes only + if self.tool_style == "openai": + tool_dicts.append( + { + "type": "function", + "function": { + "name": tool.name, + }, + } + ) + elif self.tool_style == "anthropic": + tool_dicts.append( + { + "name": tool.name, + } + ) + + return self.bind(tools=tool_dicts) diff --git a/libs/langchain_v1/scripts/deepagent.py b/libs/langchain_v1/scripts/deepagent.py new file mode 100644 index 00000000000..d301aebd2e8 --- /dev/null +++ b/libs/langchain_v1/scripts/deepagent.py @@ -0,0 +1,9 @@ +from langchain.agents import create_agent +from langchain.agents.middleware.deepagents import DeepAgentMiddleware +from langchain.chat_models.fake import FakeToolCallingModel + +model = FakeToolCallingModel() +agent = create_agent(model, [], middleware=[DeepAgentMiddleware()]) + +for s in agent.stream({"messages": [{"role": "user", "content": "hi"}]}, stream_mode="debug"): + print(s) diff --git a/libs/langchain_v1/scripts/rag.py b/libs/langchain_v1/scripts/rag.py new file mode 100644 index 00000000000..44c109d9f3a --- /dev/null +++ b/libs/langchain_v1/scripts/rag.py @@ -0,0 +1,24 @@ +from langchain_core.callbacks import CallbackManagerForRetrieverRun +from langchain_core.documents import Document + +from langchain.agents import create_agent +from langchain.agents.middleware.rag import RAGMiddleware +from langchain.chat_models.fake import FakeToolCallingModel +from langchain_core.retrievers import BaseRetriever + +tool_calls = [[{"args": {}, "id": "1", "name": "handoff_to_foo2"}], []] + +class FakeRetriever(BaseRetriever): + + def _get_relevant_documents(self, query: str, *, + run_manager: CallbackManagerForRetrieverRun) -> list[ + Document]: + return [Document(page_content="foo")] + + +model = FakeToolCallingModel() +middleware = RAGMiddleware.from_retriever(FakeRetriever(), "foo") +agent = create_agent(model, [], middleware=[middleware]) +print(agent.get_graph()) +for s in agent.stream({"messages": [{"role": "user", "content": "hi"}]}, stream_mode="debug"): + print(s) diff --git a/libs/langchain_v1/scripts/swarm.py b/libs/langchain_v1/scripts/swarm.py new file mode 100644 index 00000000000..d1950bd9163 --- /dev/null +++ b/libs/langchain_v1/scripts/swarm.py @@ -0,0 +1,15 @@ +from langchain.agents import create_agent +from langchain.agents.middleware.swarm import SwarmMiddleware +from langchain.chat_models.fake import FakeToolCallingModel + +tool_calls = [[{"args": {}, "id": "1", "name": "handoff_to_foo2"}], []] +model = FakeToolCallingModel(tool_calls=tool_calls) +subagents = [ + {"name": "foo1", "description": "bar1", "prompt": "hi", "tools": []}, + {"name": "foo2", "description": "bar1", "prompt": "bye", "tools": []} +] +middleware = SwarmMiddleware(agents=subagents, starting_agent="foo1") +agent = create_agent(model, [], middleware=[middleware]) +print(agent.get_graph()) +for s in agent.stream({"messages": [{"role": "user", "content": "hi"}]}, stream_mode="debug"): + print(s)