mirror of
https://github.com/hwchase17/langchain.git
synced 2026-02-05 08:40:36 +00:00
Compare commits
4 Commits
langchain=
...
sr/agent-r
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
ea3ec45441 | ||
|
|
28c02783fa | ||
|
|
22e7deb4b7 | ||
|
|
98122b040b |
@@ -18,6 +18,7 @@ from typing_extensions import NotRequired, Required, TypedDict, TypeVar
|
||||
|
||||
from langchain.agents.middleware.types import (
|
||||
AgentMiddleware,
|
||||
AgentRuntime,
|
||||
AgentState,
|
||||
JumpTo,
|
||||
ModelRequest,
|
||||
@@ -1018,6 +1019,9 @@ def create_agent( # noqa: PLR0915
|
||||
|
||||
def model_node(state: AgentState, runtime: Runtime[ContextT]) -> dict[str, Any]:
|
||||
"""Sync model request handler with sequential middleware processing."""
|
||||
# Create flat AgentRuntime with all runtime properties
|
||||
agent_runtime = AgentRuntime.from_runtime(name or "agent", runtime)
|
||||
|
||||
request = ModelRequest(
|
||||
model=model,
|
||||
tools=default_tools,
|
||||
@@ -1026,7 +1030,7 @@ def create_agent( # noqa: PLR0915
|
||||
messages=state["messages"],
|
||||
tool_choice=None,
|
||||
state=state,
|
||||
runtime=runtime,
|
||||
runtime=agent_runtime,
|
||||
)
|
||||
|
||||
if wrap_model_call_handler is None:
|
||||
@@ -1071,6 +1075,9 @@ def create_agent( # noqa: PLR0915
|
||||
|
||||
async def amodel_node(state: AgentState, runtime: Runtime[ContextT]) -> dict[str, Any]:
|
||||
"""Async model request handler with sequential middleware processing."""
|
||||
# Create flat AgentRuntime with all runtime properties
|
||||
agent_runtime = AgentRuntime.from_runtime(name or "agent", runtime)
|
||||
|
||||
request = ModelRequest(
|
||||
model=model,
|
||||
tools=default_tools,
|
||||
@@ -1079,7 +1086,7 @@ def create_agent( # noqa: PLR0915
|
||||
messages=state["messages"],
|
||||
tool_choice=None,
|
||||
state=state,
|
||||
runtime=runtime,
|
||||
runtime=agent_runtime,
|
||||
)
|
||||
|
||||
if awrap_model_call_handler is None:
|
||||
|
||||
@@ -32,6 +32,7 @@ from .tool_retry import ToolRetryMiddleware
|
||||
from .tool_selection import LLMToolSelectorMiddleware
|
||||
from .types import (
|
||||
AgentMiddleware,
|
||||
AgentRuntime,
|
||||
AgentState,
|
||||
ModelRequest,
|
||||
ModelResponse,
|
||||
@@ -47,6 +48,7 @@ from .types import (
|
||||
|
||||
__all__ = [
|
||||
"AgentMiddleware",
|
||||
"AgentRuntime",
|
||||
"AgentState",
|
||||
"ClearToolUsesEdit",
|
||||
"CodexSandboxExecutionPolicy",
|
||||
|
||||
@@ -27,7 +27,8 @@ from typing import TypeAlias
|
||||
from langchain_core.messages import AIMessage, AnyMessage, BaseMessage, ToolMessage # noqa: TC002
|
||||
from langgraph.channels.ephemeral_value import EphemeralValue
|
||||
from langgraph.graph.message import add_messages
|
||||
from langgraph.types import Command # noqa: TC002
|
||||
from langgraph.store.base import BaseStore # noqa: TC002
|
||||
from langgraph.types import Command, StreamWriter # noqa: TC002
|
||||
from langgraph.typing import ContextT
|
||||
from typing_extensions import NotRequired, Required, TypedDict, TypeVar, Unpack
|
||||
|
||||
@@ -60,6 +61,75 @@ JumpTo = Literal["tools", "model", "end"]
|
||||
ResponseT = TypeVar("ResponseT")
|
||||
|
||||
|
||||
@dataclass
|
||||
class AgentRuntime(Generic[ContextT]):
|
||||
"""Runtime context for agent execution, extending LangGraph's Runtime.
|
||||
|
||||
This class provides agent-specific execution context to middleware, including
|
||||
the name of the currently executing graph and all Runtime properties flattened
|
||||
for convenient access.
|
||||
|
||||
The AgentRuntime follows the same pattern as ToolRuntime, providing a flat
|
||||
structure with all runtime properties directly accessible.
|
||||
|
||||
Attributes:
|
||||
agent_name: The name of the currently executing graph/agent. This is the
|
||||
name passed to `create_agent(name=...)` or defaults to "LangGraph".
|
||||
context: Static context for the graph run (e.g., `user_id`, `db_conn`).
|
||||
store: Store for persistence and memory, if configured.
|
||||
stream_writer: Function for writing to the custom stream.
|
||||
previous: The previous return value for the given thread (functional API only).
|
||||
|
||||
Example:
|
||||
```python
|
||||
from langchain.agents.middleware import wrap_model_call, AgentRuntime
|
||||
from langchain.agents.middleware.types import ModelRequest, ModelResponse
|
||||
|
||||
|
||||
@wrap_model_call
|
||||
def log_agent_name(
|
||||
request: ModelRequest,
|
||||
handler: Callable[[ModelRequest], ModelResponse],
|
||||
) -> ModelResponse:
|
||||
'''Log which agent is making the model call.'''
|
||||
agent_name = request.runtime.agent_name
|
||||
print(f"Agent '{agent_name}' is calling the model")
|
||||
|
||||
# Access runtime context directly (flattened)
|
||||
user_id = request.runtime.context.get("user_id")
|
||||
print(f"User: {user_id}")
|
||||
|
||||
return handler(request)
|
||||
```
|
||||
"""
|
||||
|
||||
agent_name: str
|
||||
"""The name of the currently executing graph/agent."""
|
||||
|
||||
context: ContextT = field(default=None) # type: ignore[assignment]
|
||||
"""Static context for the graph run, like `user_id`, `db_conn`, etc."""
|
||||
|
||||
store: BaseStore | None = field(default=None)
|
||||
"""Store for the graph run, enabling persistence and memory."""
|
||||
|
||||
stream_writer: StreamWriter = field(default=None) # type: ignore[assignment]
|
||||
"""Function that writes to the custom stream."""
|
||||
|
||||
previous: Any = field(default=None)
|
||||
"""The previous return value for the given thread."""
|
||||
|
||||
@classmethod
|
||||
def from_runtime(cls, name: str, runtime: Runtime[ContextT]) -> AgentRuntime[ContextT]:
|
||||
"""Create an AgentRuntime from a Runtime."""
|
||||
return AgentRuntime[ContextT](
|
||||
agent_name=name,
|
||||
context=runtime.context,
|
||||
store=runtime.store,
|
||||
stream_writer=runtime.stream_writer,
|
||||
previous=runtime.previous,
|
||||
)
|
||||
|
||||
|
||||
class _ModelRequestOverrides(TypedDict, total=False):
|
||||
"""Possible overrides for ModelRequest.override() method."""
|
||||
|
||||
@@ -74,7 +144,23 @@ class _ModelRequestOverrides(TypedDict, total=False):
|
||||
|
||||
@dataclass
|
||||
class ModelRequest:
|
||||
"""Model request information for the agent."""
|
||||
"""Model request information for the agent.
|
||||
|
||||
This dataclass contains all the information needed for a model invocation,
|
||||
including the model, messages, tools, and runtime context.
|
||||
|
||||
Attributes:
|
||||
model: The chat model to invoke.
|
||||
system_prompt: Optional system prompt to prepend to messages.
|
||||
messages: List of conversation messages (excluding system prompt).
|
||||
tool_choice: Tool selection configuration for the model.
|
||||
tools: Available tools for the model to use.
|
||||
response_format: Structured output format specification.
|
||||
state: Complete agent state at the time of model invocation.
|
||||
runtime: Agent runtime context including agent name and underlying
|
||||
LangGraph Runtime with context, store, and stream_writer.
|
||||
model_settings: Additional model-specific settings.
|
||||
"""
|
||||
|
||||
model: BaseChatModel
|
||||
system_prompt: str | None
|
||||
@@ -83,7 +169,7 @@ class ModelRequest:
|
||||
tools: list[BaseTool | dict]
|
||||
response_format: ResponseFormat | None
|
||||
state: AgentState
|
||||
runtime: Runtime[ContextT] # type: ignore[valid-type]
|
||||
runtime: AgentRuntime[ContextT] # type: ignore[valid-type]
|
||||
model_settings: dict[str, Any] = field(default_factory=dict)
|
||||
|
||||
def override(self, **overrides: Unpack[_ModelRequestOverrides]) -> ModelRequest:
|
||||
@@ -932,7 +1018,7 @@ def before_agent(
|
||||
```python
|
||||
@before_agent
|
||||
def log_before_agent(state: AgentState, runtime: Runtime) -> None:
|
||||
print(f"Starting agent with {len(state['messages'])} messages")
|
||||
print(f"Starting agent '{runtime.agent_name}' with {len(state['messages'])} messages")
|
||||
```
|
||||
|
||||
With conditional jumping:
|
||||
|
||||
108
libs/langchain_v1/tests/unit_tests/agents/test_agent_runtime.py
Normal file
108
libs/langchain_v1/tests/unit_tests/agents/test_agent_runtime.py
Normal file
@@ -0,0 +1,108 @@
|
||||
"""Tests for AgentRuntime access via wrap_model_call middleware."""
|
||||
|
||||
import pytest
|
||||
from langchain_core.language_models.fake_chat_models import GenericFakeChatModel
|
||||
from langchain_core.messages import AIMessage, HumanMessage
|
||||
from langchain_core.tools import tool
|
||||
|
||||
from langchain.agents import create_agent
|
||||
from langchain.agents.middleware import wrap_model_call
|
||||
from langchain.agents.middleware.types import ModelRequest
|
||||
from langchain.tools import ToolRuntime
|
||||
|
||||
from .model import FakeToolCallingModel
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def fake_chat_model():
|
||||
"""Fixture providing a fake chat model for testing."""
|
||||
return GenericFakeChatModel(messages=iter([AIMessage(content="test response")]))
|
||||
|
||||
|
||||
def test_agent_name_accessible_in_middleware(fake_chat_model):
|
||||
"""Test that agent name can be accessed via middleware."""
|
||||
captured_agent_name = None
|
||||
|
||||
@wrap_model_call
|
||||
def capture_agent_name(request: ModelRequest, handler):
|
||||
nonlocal captured_agent_name
|
||||
captured_agent_name = request.runtime.agent_name
|
||||
return handler(request)
|
||||
|
||||
agent = create_agent(
|
||||
fake_chat_model,
|
||||
tools=[],
|
||||
middleware=[capture_agent_name],
|
||||
name="TestAgent",
|
||||
)
|
||||
|
||||
agent.invoke({"messages": [HumanMessage("Hello")]})
|
||||
assert captured_agent_name == "TestAgent"
|
||||
|
||||
|
||||
def test_nested_agent_name_accessible_in_tool():
|
||||
"""Test that nested agent's name is accessible when agent is used in a tool."""
|
||||
# Track which agent names were captured
|
||||
captured_agent_names = []
|
||||
|
||||
@wrap_model_call
|
||||
def capture_agent_name(request: ModelRequest, handler):
|
||||
captured_agent_names.append(request.runtime.agent_name)
|
||||
return handler(request)
|
||||
|
||||
# Create a nested agent that will be called from within a tool
|
||||
nested_agent = create_agent(
|
||||
FakeToolCallingModel(),
|
||||
tools=[],
|
||||
middleware=[capture_agent_name],
|
||||
name="NestedAgent",
|
||||
)
|
||||
|
||||
# Create a tool that invokes the nested agent
|
||||
@tool
|
||||
def call_nested_agent(query: str, runtime: ToolRuntime) -> str:
|
||||
"""Tool that calls a nested agent."""
|
||||
result = nested_agent.invoke({"messages": [HumanMessage(query)]})
|
||||
return result["messages"][-1].content
|
||||
|
||||
# Create outer agent that uses the tool
|
||||
outer_agent = create_agent(
|
||||
FakeToolCallingModel(
|
||||
tool_calls=[
|
||||
[{"name": "call_nested_agent", "args": {"query": "test"}, "id": "1"}],
|
||||
[],
|
||||
]
|
||||
),
|
||||
tools=[call_nested_agent],
|
||||
middleware=[capture_agent_name],
|
||||
name="OuterAgent",
|
||||
)
|
||||
|
||||
# Invoke the outer agent, which should call the tool, which calls the nested agent
|
||||
outer_agent.invoke({"messages": [HumanMessage("Hello")]})
|
||||
|
||||
# Both agents should have captured their names
|
||||
assert "OuterAgent" in captured_agent_names
|
||||
assert "NestedAgent" in captured_agent_names
|
||||
|
||||
|
||||
async def test_agent_name_accessible_in_async_middleware():
|
||||
"""Test that agent name can be accessed in async middleware."""
|
||||
captured_agent_name = None
|
||||
|
||||
@wrap_model_call
|
||||
async def capture_agent_name_async(request: ModelRequest, handler):
|
||||
nonlocal captured_agent_name
|
||||
captured_agent_name = request.runtime.agent_name
|
||||
return await handler(request)
|
||||
|
||||
fake_model = GenericFakeChatModel(messages=iter([AIMessage(content="async response")]))
|
||||
agent = create_agent(
|
||||
fake_model,
|
||||
tools=[],
|
||||
middleware=[capture_agent_name_async],
|
||||
name="AsyncAgent",
|
||||
)
|
||||
|
||||
await agent.ainvoke({"messages": [HumanMessage("Hello")]})
|
||||
assert captured_agent_name == "AsyncAgent"
|
||||
@@ -1351,7 +1351,7 @@ def test_public_private_state_for_custom_middleware() -> None:
|
||||
class CustomMiddleware(AgentMiddleware[CustomState]):
|
||||
state_schema: type[CustomState] = CustomState
|
||||
|
||||
def before_model(self, state: CustomState) -> dict[str, Any]:
|
||||
def before_model(self, state: CustomState, runtime) -> dict[str, Any]:
|
||||
assert "omit_input" not in state
|
||||
assert "omit_output" in state
|
||||
assert "private_state" not in state
|
||||
@@ -1456,11 +1456,11 @@ def test_injected_state_in_middleware_agent() -> None:
|
||||
|
||||
def test_jump_to_is_ephemeral() -> None:
|
||||
class MyMiddleware(AgentMiddleware):
|
||||
def before_model(self, state: AgentState) -> dict[str, Any]:
|
||||
def before_model(self, state: AgentState, runtime) -> dict[str, Any]:
|
||||
assert "jump_to" not in state
|
||||
return {"jump_to": "model"}
|
||||
|
||||
def after_model(self, state: AgentState) -> dict[str, Any]:
|
||||
def after_model(self, state: AgentState, runtime) -> dict[str, Any]:
|
||||
assert "jump_to" not in state
|
||||
return {"jump_to": "model"}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user