Compare commits

...

1 Commits

Author SHA1 Message Date
Sydney Runkle
89d10ca1a9 new typing 2025-10-11 07:34:42 -04:00
4 changed files with 197 additions and 56 deletions

View File

@@ -20,7 +20,11 @@ from .tool_selection import LLMToolSelectorMiddleware
from .types import (
AgentMiddleware,
AgentState,
ModelCallHandler,
ModelCallResult,
ModelCallWrapper,
ModelRequest,
ModelResponse,
after_agent,
after_model,
before_agent,
@@ -28,6 +32,7 @@ from .types import (
dynamic_prompt,
hook_config,
wrap_model_call,
wrap_tool_call,
)
__all__ = [
@@ -41,9 +46,13 @@ __all__ = [
"InterruptOnConfig",
"LLMToolEmulator",
"LLMToolSelectorMiddleware",
"ModelCallHandler",
"ModelCallLimitMiddleware",
"ModelCallResult",
"ModelCallWrapper",
"ModelFallbackMiddleware",
"ModelRequest",
"ModelResponse",
"PIIDetectionError",
"PIIMiddleware",
"PlanningMiddleware",
@@ -56,4 +65,5 @@ __all__ = [
"dynamic_prompt",
"hook_config",
"wrap_model_call",
"wrap_tool_call",
]

View File

@@ -17,9 +17,11 @@ from typing import (
)
if TYPE_CHECKING:
from collections.abc import Awaitable
from langchain.tools.tool_node import ToolCallRequest
from langchain.tools.tool_node import (
AsyncToolCallHandler,
ToolCallHandler,
ToolCallRequest,
)
# Needed as top level import for Pydantic schema generation on AgentState
from typing import TypeAlias
@@ -43,6 +45,9 @@ __all__ = [
"AgentMiddleware",
"AgentState",
"ContextT",
"ModelCallHandler",
"ModelCallResult",
"ModelCallWrapper",
"ModelRequest",
"ModelResponse",
"OmitFromSchema",
@@ -53,6 +58,7 @@ __all__ = [
"before_model",
"dynamic_prompt",
"hook_config",
"wrap_model_call",
"wrap_tool_call",
]
@@ -102,6 +108,82 @@ Middleware can return either:
"""
# Synchronous handler signature: takes one ModelRequest, returns one ModelResponse.
ModelCallHandler = Callable[[ModelRequest], ModelResponse]
"""Type alias for the handler callback passed to wrap_model_call hooks.
The handler executes the model request and returns a ModelResponse. It can be called
multiple times for retry logic or skipped entirely to short-circuit execution.
Examples:
Simple passthrough:
```python
def my_wrapper(request: ModelRequest, handler: ModelCallHandler) -> ModelCallResult:
return handler(request)
```
Retry logic:
```python
def retry_wrapper(request: ModelRequest, handler: ModelCallHandler) -> ModelCallResult:
for attempt in range(3):
try:
return handler(request)
except Exception:
if attempt == 2:
raise
```
"""
# NOTE(review): this alias is evaluated when the module is imported, so
# `Awaitable` must exist at runtime. The import hunk for this file shows
# `from collections.abc import Awaitable` under `if TYPE_CHECKING:` — confirm it
# is also imported unconditionally, otherwise this line raises NameError on import.
AsyncModelCallHandler = Callable[[ModelRequest], Awaitable[ModelResponse]]
"""Type alias for the async handler callback passed to wrap_model_call hooks.
The async handler executes the model request and returns a ModelResponse. It can be
called multiple times for retry logic or skipped entirely to short-circuit execution.
"""
# Synchronous wrapper signature: (request, handler) -> ModelCallResult.
ModelCallWrapper = Callable[[ModelRequest, ModelCallHandler], ModelCallResult]
"""Type alias for synchronous model call wrapper functions.
A wrapper receives a ModelRequest and a handler callback. It can modify the request,
call the handler (potentially multiple times), modify the response, or short-circuit
entirely.
Args:
request: Model request containing state, runtime, messages, tools, etc.
handler: Callback to execute the model. Can be called multiple times.
Returns:
ModelCallResult (either ModelResponse or AIMessage)
Examples:
Basic retry pattern:
```python
def retry_on_error(request: ModelRequest, handler: ModelCallHandler) -> ModelCallResult:
for attempt in range(3):
try:
return handler(request)
except Exception:
if attempt == 2:
raise
```
Access runtime context:
```python
def use_runtime(request: ModelRequest, handler: ModelCallHandler) -> ModelCallResult:
user_id = request.runtime.context.get("user_id")
# Modify request based on context
return handler(request)
```
"""
# NOTE(review): evaluated at import time, so `Awaitable` must be a runtime name;
# the visible import hunk places it under `if TYPE_CHECKING:` — confirm an
# unconditional import exists. Also, this name is not among the `__all__`
# additions visible in this change — confirm whether it should be exported.
AsyncModelCallWrapper = Callable[[ModelRequest, AsyncModelCallHandler], Awaitable[ModelCallResult]]
"""Type alias for asynchronous model call wrapper functions.
A wrapper receives a ModelRequest and an async handler callback. It can modify the
request, call the handler (potentially multiple times), modify the response, or
short-circuit entirely.
"""
@dataclass
class OmitFromSchema:
"""Annotation used to mark state attributes as omitted from input or output schemas."""
@@ -195,7 +277,7 @@ class AgentMiddleware(Generic[StateT, ContextT]):
def wrap_model_call(
self,
request: ModelRequest,
handler: Callable[[ModelRequest], ModelResponse],
handler: ModelCallHandler,
) -> ModelCallResult:
"""Intercept and control model execution via handler callback.
@@ -278,7 +360,7 @@ class AgentMiddleware(Generic[StateT, ContextT]):
async def awrap_model_call(
self,
request: ModelRequest,
handler: Callable[[ModelRequest], Awaitable[ModelResponse]],
handler: AsyncModelCallHandler,
) -> ModelCallResult:
"""Intercept and control async model execution via handler callback.
@@ -331,7 +413,7 @@ class AgentMiddleware(Generic[StateT, ContextT]):
def wrap_tool_call(
self,
request: ToolCallRequest,
handler: Callable[[ToolCallRequest], ToolMessage | Command],
handler: ToolCallHandler,
) -> ToolMessage | Command:
"""Intercept tool execution for retries, monitoring, or modification.
@@ -395,7 +477,7 @@ class AgentMiddleware(Generic[StateT, ContextT]):
async def awrap_tool_call(
self,
request: ToolCallRequest,
handler: Callable[[ToolCallRequest], Awaitable[ToolMessage | Command]],
handler: AsyncToolCallHandler,
) -> ToolMessage | Command:
"""Intercept and control async tool execution via handler callback.
@@ -480,7 +562,7 @@ class _CallableReturningModelResponse(Protocol[StateT_contra, ContextT]): # typ
def __call__(
self,
request: ModelRequest,
handler: Callable[[ModelRequest], ModelResponse],
handler: ModelCallHandler,
) -> ModelCallResult:
"""Intercept model execution via handler callback."""
...
@@ -495,7 +577,7 @@ class _CallableReturningToolResponse(Protocol):
def __call__(
self,
request: ToolCallRequest,
handler: Callable[[ToolCallRequest], ToolMessage | Command],
handler: ToolCallHandler,
) -> ToolMessage | Command:
"""Intercept tool execution via handler callback."""
...
@@ -1174,7 +1256,7 @@ def dynamic_prompt(
async def async_wrapped(
self: AgentMiddleware[StateT, ContextT], # noqa: ARG001
request: ModelRequest,
handler: Callable[[ModelRequest], Awaitable[ModelResponse]],
handler: AsyncModelCallHandler,
) -> ModelCallResult:
prompt = await func(request) # type: ignore[misc]
request.system_prompt = prompt
@@ -1195,7 +1277,7 @@ def dynamic_prompt(
def wrapped(
self: AgentMiddleware[StateT, ContextT], # noqa: ARG001
request: ModelRequest,
handler: Callable[[ModelRequest], ModelResponse],
handler: ModelCallHandler,
) -> ModelCallResult:
prompt = cast("str", func(request))
request.system_prompt = prompt
@@ -1204,7 +1286,7 @@ def dynamic_prompt(
async def async_wrapped_from_sync(
self: AgentMiddleware[StateT, ContextT], # noqa: ARG001
request: ModelRequest,
handler: Callable[[ModelRequest], Awaitable[ModelResponse]],
handler: AsyncModelCallHandler,
) -> ModelCallResult:
# Delegate to sync function
prompt = cast("str", func(request))
@@ -1337,7 +1419,7 @@ def wrap_model_call(
async def async_wrapped(
self: AgentMiddleware[StateT, ContextT], # noqa: ARG001
request: ModelRequest,
handler: Callable[[ModelRequest], Awaitable[ModelResponse]],
handler: AsyncModelCallHandler,
) -> ModelCallResult:
return await func(request, handler) # type: ignore[misc, arg-type]
@@ -1358,7 +1440,7 @@ def wrap_model_call(
def wrapped(
self: AgentMiddleware[StateT, ContextT], # noqa: ARG001
request: ModelRequest,
handler: Callable[[ModelRequest], ModelResponse],
handler: ModelCallHandler,
) -> ModelCallResult:
return func(request, handler)
@@ -1480,7 +1562,7 @@ def wrap_tool_call(
async def async_wrapped(
self: AgentMiddleware, # noqa: ARG001
request: ToolCallRequest,
handler: Callable[[ToolCallRequest], Awaitable[ToolMessage | Command]],
handler: AsyncToolCallHandler,
) -> ToolMessage | Command:
return await func(request, handler) # type: ignore[arg-type,misc]
@@ -1501,7 +1583,7 @@ def wrap_tool_call(
def wrapped(
self: AgentMiddleware, # noqa: ARG001
request: ToolCallRequest,
handler: Callable[[ToolCallRequest], ToolMessage | Command],
handler: ToolCallHandler,
) -> ToolMessage | Command:
return func(request, handler)

View File

@@ -8,14 +8,28 @@ from langchain_core.tools import (
tool,
)
from langchain.tools.tool_node import InjectedState, InjectedStore, ToolInvocationError
from langchain.tools.tool_node import (
AsyncToolCallHandler,
AsyncToolCallWrapper,
InjectedState,
InjectedStore,
ToolCallHandler,
ToolCallRequest,
ToolCallWrapper,
ToolInvocationError,
)
__all__ = [
"AsyncToolCallHandler",
"AsyncToolCallWrapper",
"BaseTool",
"InjectedState",
"InjectedStore",
"InjectedToolArg",
"InjectedToolCallId",
"ToolCallHandler",
"ToolCallRequest",
"ToolCallWrapper",
"ToolException",
"ToolInvocationError",
"tool",

View File

@@ -118,26 +118,55 @@ class ToolCallRequest:
tool_call: ToolCall
tool: BaseTool
state: Any
runtime: Any
runtime: Any # Runtime[Any] | None, but using Any for simplicity and to avoid circular imports
ToolCallWrapper = Callable[
[ToolCallRequest, Callable[[ToolCallRequest], ToolMessage | Command]],
ToolMessage | Command,
]
"""Wrapper for tool call execution with multi-call support.
ToolCallHandler = Callable[[ToolCallRequest], ToolMessage | Command]
"""Type alias for the handler callback passed to wrap_tool_call hooks.
Wrapper receives:
request: ToolCallRequest with tool_call, tool, state, and runtime.
execute: Callable to execute the tool (CAN BE CALLED MULTIPLE TIMES).
The handler executes the tool call and returns a ToolMessage or Command. It can be
called multiple times for retry logic or skipped entirely to short-circuit execution.
Examples:
Simple passthrough:
```python
def my_wrapper(request: ToolCallRequest, handler: ToolCallHandler) -> ToolMessage | Command:
return handler(request)
```
Retry logic:
```python
def retry_wrapper(request: ToolCallRequest, handler: ToolCallHandler) -> ToolMessage | Command:
for attempt in range(3):
try:
return handler(request)
except Exception:
if attempt == 2:
raise
```
"""
AsyncToolCallHandler = Callable[[ToolCallRequest], Awaitable[ToolMessage | Command]]
"""Type alias for the async handler callback passed to wrap_tool_call hooks.
The async handler executes the tool call and returns a ToolMessage or Command. It can
be called multiple times for retry logic or skipped entirely to short-circuit execution.
"""
ToolCallWrapper = Callable[[ToolCallRequest, ToolCallHandler], ToolMessage | Command]
"""Type alias for synchronous tool call wrapper functions.
A wrapper receives a ToolCallRequest and a handler callback. It can modify the request,
call the handler (potentially multiple times), modify the response, or short-circuit
entirely.
Args:
request: Tool call request with tool_call, tool, state, and runtime.
handler: Callback to execute the tool. Can be called multiple times.
Returns:
ToolMessage or Command (the final result).
The handler callable can be invoked multiple times for retry logic,
with potentially modified requests each time. Each call to handler
is independent and stateless.
Note:
When implementing middleware for `create_agent`, use
`AgentMiddleware.wrap_tool_call` which provides properly typed
@@ -145,55 +174,61 @@ Note:
Examples:
Passthrough (execute once):
def handler(request, execute):
return execute(request)
```python
def passthrough(request: ToolCallRequest, handler: ToolCallHandler) -> ToolMessage | Command:
return handler(request)
```
Modify request before execution:
def handler(request, execute):
```python
def modify_args(request: ToolCallRequest, handler: ToolCallHandler) -> ToolMessage | Command:
request.tool_call["args"]["value"] *= 2
return execute(request)
return handler(request)
```
Retry on error (execute multiple times):
def handler(request, execute):
```python
def retry_on_error(request: ToolCallRequest, handler: ToolCallHandler) -> ToolMessage | Command:
for attempt in range(3):
try:
result = execute(request)
result = handler(request)
if is_valid(result):
return result
except Exception:
if attempt == 2:
raise
return result
```
Conditional retry based on response:
Access runtime context:
```python
def use_runtime(request: ToolCallRequest, handler: ToolCallHandler) -> ToolMessage | Command:
if request.runtime is not None:
thread_id = request.runtime.context.get("thread_id")
# Use runtime context
return handler(request)
```
def handler(request, execute):
for attempt in range(3):
result = execute(request)
if isinstance(result, ToolMessage) and result.status != "error":
return result
if attempt < 2:
continue
return result
Cache/short-circuit without calling execute:
def handler(request, execute):
Cache/short-circuit without calling handler:
```python
def with_cache(request: ToolCallRequest, handler: ToolCallHandler) -> ToolMessage | Command:
if cached := get_cache(request):
return ToolMessage(content=cached, tool_call_id=request.tool_call["id"])
result = execute(request)
result = handler(request)
save_cache(request, result)
return result
```
"""
AsyncToolCallWrapper = Callable[
[ToolCallRequest, Callable[[ToolCallRequest], Awaitable[ToolMessage | Command]]],
Awaitable[ToolMessage | Command],
[ToolCallRequest, AsyncToolCallHandler], Awaitable[ToolMessage | Command]
]
"""Async wrapper for tool call execution with multi-call support."""
"""Type alias for asynchronous tool call wrapper functions.
A wrapper receives a ToolCallRequest and an async handler callback. It can modify the
request, call the handler (potentially multiple times), modify the response, or
short-circuit entirely.
"""
class ToolCallWithContext(TypedDict):