Compare commits

...

4 Commits

Author SHA1 Message Date
Sydney Runkle
ea3ec45441 better tests 2025-10-27 14:50:33 -07:00
Sydney Runkle
28c02783fa tests 2025-10-27 14:44:46 -07:00
Sydney Runkle
22e7deb4b7 flat 2025-10-27 11:13:57 -07:00
Sydney Runkle
98122b040b agent runtime poc 2025-10-27 11:01:31 -07:00
5 changed files with 212 additions and 9 deletions

View File

@@ -18,6 +18,7 @@ from typing_extensions import NotRequired, Required, TypedDict, TypeVar
from langchain.agents.middleware.types import (
AgentMiddleware,
AgentRuntime,
AgentState,
JumpTo,
ModelRequest,
@@ -1018,6 +1019,9 @@ def create_agent( # noqa: PLR0915
def model_node(state: AgentState, runtime: Runtime[ContextT]) -> dict[str, Any]:
"""Sync model request handler with sequential middleware processing."""
# Create flat AgentRuntime with all runtime properties
agent_runtime = AgentRuntime.from_runtime(name or "agent", runtime)
request = ModelRequest(
model=model,
tools=default_tools,
@@ -1026,7 +1030,7 @@ def create_agent( # noqa: PLR0915
messages=state["messages"],
tool_choice=None,
state=state,
runtime=runtime,
runtime=agent_runtime,
)
if wrap_model_call_handler is None:
@@ -1071,6 +1075,9 @@ def create_agent( # noqa: PLR0915
async def amodel_node(state: AgentState, runtime: Runtime[ContextT]) -> dict[str, Any]:
"""Async model request handler with sequential middleware processing."""
# Create flat AgentRuntime with all runtime properties
agent_runtime = AgentRuntime.from_runtime(name or "agent", runtime)
request = ModelRequest(
model=model,
tools=default_tools,
@@ -1079,7 +1086,7 @@ def create_agent( # noqa: PLR0915
messages=state["messages"],
tool_choice=None,
state=state,
runtime=runtime,
runtime=agent_runtime,
)
if awrap_model_call_handler is None:

View File

@@ -32,6 +32,7 @@ from .tool_retry import ToolRetryMiddleware
from .tool_selection import LLMToolSelectorMiddleware
from .types import (
AgentMiddleware,
AgentRuntime,
AgentState,
ModelRequest,
ModelResponse,
@@ -47,6 +48,7 @@ from .types import (
__all__ = [
"AgentMiddleware",
"AgentRuntime",
"AgentState",
"ClearToolUsesEdit",
"CodexSandboxExecutionPolicy",

View File

@@ -27,7 +27,8 @@ from typing import TypeAlias
from langchain_core.messages import AIMessage, AnyMessage, BaseMessage, ToolMessage # noqa: TC002
from langgraph.channels.ephemeral_value import EphemeralValue
from langgraph.graph.message import add_messages
from langgraph.types import Command # noqa: TC002
from langgraph.store.base import BaseStore # noqa: TC002
from langgraph.types import Command, StreamWriter # noqa: TC002
from langgraph.typing import ContextT
from typing_extensions import NotRequired, Required, TypedDict, TypeVar, Unpack
@@ -60,6 +61,75 @@ JumpTo = Literal["tools", "model", "end"]
ResponseT = TypeVar("ResponseT")
@dataclass
class AgentRuntime(Generic[ContextT]):
"""Runtime context for agent execution, extending LangGraph's Runtime.
This class provides agent-specific execution context to middleware, including
the name of the currently executing graph and all Runtime properties flattened
for convenient access.
The AgentRuntime follows the same pattern as ToolRuntime, providing a flat
structure with all runtime properties directly accessible.
Attributes:
agent_name: The name of the currently executing graph/agent. This is the
name passed to `create_agent(name=...)` or defaults to "LangGraph".
context: Static context for the graph run (e.g., `user_id`, `db_conn`).
store: Store for persistence and memory, if configured.
stream_writer: Function for writing to the custom stream.
previous: The previous return value for the given thread (functional API only).
Example:
```python
from langchain.agents.middleware import wrap_model_call, AgentRuntime
from langchain.agents.middleware.types import ModelRequest, ModelResponse
@wrap_model_call
def log_agent_name(
request: ModelRequest,
handler: Callable[[ModelRequest], ModelResponse],
) -> ModelResponse:
'''Log which agent is making the model call.'''
agent_name = request.runtime.agent_name
print(f"Agent '{agent_name}' is calling the model")
# Access runtime context directly (flattened)
user_id = request.runtime.context.get("user_id")
print(f"User: {user_id}")
return handler(request)
```
"""
agent_name: str
"""The name of the currently executing graph/agent."""
context: ContextT = field(default=None) # type: ignore[assignment]
"""Static context for the graph run, like `user_id`, `db_conn`, etc."""
store: BaseStore | None = field(default=None)
"""Store for the graph run, enabling persistence and memory."""
stream_writer: StreamWriter = field(default=None) # type: ignore[assignment]
"""Function that writes to the custom stream."""
previous: Any = field(default=None)
"""The previous return value for the given thread."""
@classmethod
def from_runtime(cls, name: str, runtime: Runtime[ContextT]) -> AgentRuntime[ContextT]:
"""Create an AgentRuntime from a Runtime."""
return AgentRuntime[ContextT](
agent_name=name,
context=runtime.context,
store=runtime.store,
stream_writer=runtime.stream_writer,
previous=runtime.previous,
)
class _ModelRequestOverrides(TypedDict, total=False):
"""Possible overrides for ModelRequest.override() method."""
@@ -74,7 +144,23 @@ class _ModelRequestOverrides(TypedDict, total=False):
@dataclass
class ModelRequest:
"""Model request information for the agent."""
"""Model request information for the agent.
This dataclass contains all the information needed for a model invocation,
including the model, messages, tools, and runtime context.
Attributes:
model: The chat model to invoke.
system_prompt: Optional system prompt to prepend to messages.
messages: List of conversation messages (excluding system prompt).
tool_choice: Tool selection configuration for the model.
tools: Available tools for the model to use.
response_format: Structured output format specification.
state: Complete agent state at the time of model invocation.
runtime: Agent runtime context including agent name and underlying
LangGraph Runtime with context, store, and stream_writer.
model_settings: Additional model-specific settings.
"""
model: BaseChatModel
system_prompt: str | None
@@ -83,7 +169,7 @@ class ModelRequest:
tools: list[BaseTool | dict]
response_format: ResponseFormat | None
state: AgentState
runtime: Runtime[ContextT] # type: ignore[valid-type]
runtime: AgentRuntime[ContextT] # type: ignore[valid-type]
model_settings: dict[str, Any] = field(default_factory=dict)
def override(self, **overrides: Unpack[_ModelRequestOverrides]) -> ModelRequest:
@@ -932,7 +1018,7 @@ def before_agent(
```python
@before_agent
def log_before_agent(state: AgentState, runtime: Runtime) -> None:
print(f"Starting agent with {len(state['messages'])} messages")
print(f"Starting agent '{runtime.agent_name}' with {len(state['messages'])} messages")
```
With conditional jumping:

View File

@@ -0,0 +1,108 @@
"""Tests for AgentRuntime access via wrap_model_call middleware."""
import pytest
from langchain_core.language_models.fake_chat_models import GenericFakeChatModel
from langchain_core.messages import AIMessage, HumanMessage
from langchain_core.tools import tool
from langchain.agents import create_agent
from langchain.agents.middleware import wrap_model_call
from langchain.agents.middleware.types import ModelRequest
from langchain.tools import ToolRuntime
from .model import FakeToolCallingModel
@pytest.fixture
def fake_chat_model():
"""Fixture providing a fake chat model for testing."""
return GenericFakeChatModel(messages=iter([AIMessage(content="test response")]))
def test_agent_name_accessible_in_middleware(fake_chat_model):
"""Test that agent name can be accessed via middleware."""
captured_agent_name = None
@wrap_model_call
def capture_agent_name(request: ModelRequest, handler):
nonlocal captured_agent_name
captured_agent_name = request.runtime.agent_name
return handler(request)
agent = create_agent(
fake_chat_model,
tools=[],
middleware=[capture_agent_name],
name="TestAgent",
)
agent.invoke({"messages": [HumanMessage("Hello")]})
assert captured_agent_name == "TestAgent"
def test_nested_agent_name_accessible_in_tool():
"""Test that nested agent's name is accessible when agent is used in a tool."""
# Track which agent names were captured
captured_agent_names = []
@wrap_model_call
def capture_agent_name(request: ModelRequest, handler):
captured_agent_names.append(request.runtime.agent_name)
return handler(request)
# Create a nested agent that will be called from within a tool
nested_agent = create_agent(
FakeToolCallingModel(),
tools=[],
middleware=[capture_agent_name],
name="NestedAgent",
)
# Create a tool that invokes the nested agent
@tool
def call_nested_agent(query: str, runtime: ToolRuntime) -> str:
"""Tool that calls a nested agent."""
result = nested_agent.invoke({"messages": [HumanMessage(query)]})
return result["messages"][-1].content
# Create outer agent that uses the tool
outer_agent = create_agent(
FakeToolCallingModel(
tool_calls=[
[{"name": "call_nested_agent", "args": {"query": "test"}, "id": "1"}],
[],
]
),
tools=[call_nested_agent],
middleware=[capture_agent_name],
name="OuterAgent",
)
# Invoke the outer agent, which should call the tool, which calls the nested agent
outer_agent.invoke({"messages": [HumanMessage("Hello")]})
# Both agents should have captured their names
assert "OuterAgent" in captured_agent_names
assert "NestedAgent" in captured_agent_names
async def test_agent_name_accessible_in_async_middleware():
"""Test that agent name can be accessed in async middleware."""
captured_agent_name = None
@wrap_model_call
async def capture_agent_name_async(request: ModelRequest, handler):
nonlocal captured_agent_name
captured_agent_name = request.runtime.agent_name
return await handler(request)
fake_model = GenericFakeChatModel(messages=iter([AIMessage(content="async response")]))
agent = create_agent(
fake_model,
tools=[],
middleware=[capture_agent_name_async],
name="AsyncAgent",
)
await agent.ainvoke({"messages": [HumanMessage("Hello")]})
assert captured_agent_name == "AsyncAgent"

View File

@@ -1351,7 +1351,7 @@ def test_public_private_state_for_custom_middleware() -> None:
class CustomMiddleware(AgentMiddleware[CustomState]):
state_schema: type[CustomState] = CustomState
def before_model(self, state: CustomState) -> dict[str, Any]:
def before_model(self, state: CustomState, runtime) -> dict[str, Any]:
assert "omit_input" not in state
assert "omit_output" in state
assert "private_state" not in state
@@ -1456,11 +1456,11 @@ def test_injected_state_in_middleware_agent() -> None:
def test_jump_to_is_ephemeral() -> None:
class MyMiddleware(AgentMiddleware):
def before_model(self, state: AgentState) -> dict[str, Any]:
def before_model(self, state: AgentState, runtime) -> dict[str, Any]:
assert "jump_to" not in state
return {"jump_to": "model"}
def after_model(self, state: AgentState) -> dict[str, Any]:
def after_model(self, state: AgentState, runtime) -> dict[str, Any]:
assert "jump_to" not in state
return {"jump_to": "model"}