mirror of
https://github.com/hwchase17/langchain.git
synced 2026-02-04 16:20:16 +00:00
Compare commits
1 Commits
mdrxy/tool
...
sr/fixing-
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
89d10ca1a9 |
@@ -20,7 +20,11 @@ from .tool_selection import LLMToolSelectorMiddleware
|
||||
from .types import (
|
||||
AgentMiddleware,
|
||||
AgentState,
|
||||
ModelCallHandler,
|
||||
ModelCallResult,
|
||||
ModelCallWrapper,
|
||||
ModelRequest,
|
||||
ModelResponse,
|
||||
after_agent,
|
||||
after_model,
|
||||
before_agent,
|
||||
@@ -28,6 +32,7 @@ from .types import (
|
||||
dynamic_prompt,
|
||||
hook_config,
|
||||
wrap_model_call,
|
||||
wrap_tool_call,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
@@ -41,9 +46,13 @@ __all__ = [
|
||||
"InterruptOnConfig",
|
||||
"LLMToolEmulator",
|
||||
"LLMToolSelectorMiddleware",
|
||||
"ModelCallHandler",
|
||||
"ModelCallLimitMiddleware",
|
||||
"ModelCallResult",
|
||||
"ModelCallWrapper",
|
||||
"ModelFallbackMiddleware",
|
||||
"ModelRequest",
|
||||
"ModelResponse",
|
||||
"PIIDetectionError",
|
||||
"PIIMiddleware",
|
||||
"PlanningMiddleware",
|
||||
@@ -56,4 +65,5 @@ __all__ = [
|
||||
"dynamic_prompt",
|
||||
"hook_config",
|
||||
"wrap_model_call",
|
||||
"wrap_tool_call",
|
||||
]
|
||||
|
||||
@@ -17,9 +17,11 @@ from typing import (
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from collections.abc import Awaitable
|
||||
|
||||
from langchain.tools.tool_node import ToolCallRequest
|
||||
from langchain.tools.tool_node import (
|
||||
AsyncToolCallHandler,
|
||||
ToolCallHandler,
|
||||
ToolCallRequest,
|
||||
)
|
||||
|
||||
# Needed as top level import for Pydantic schema generation on AgentState
|
||||
from typing import TypeAlias
|
||||
@@ -43,6 +45,9 @@ __all__ = [
|
||||
"AgentMiddleware",
|
||||
"AgentState",
|
||||
"ContextT",
|
||||
"ModelCallHandler",
|
||||
"ModelCallResult",
|
||||
"ModelCallWrapper",
|
||||
"ModelRequest",
|
||||
"ModelResponse",
|
||||
"OmitFromSchema",
|
||||
@@ -53,6 +58,7 @@ __all__ = [
|
||||
"before_model",
|
||||
"dynamic_prompt",
|
||||
"hook_config",
|
||||
"wrap_model_call",
|
||||
"wrap_tool_call",
|
||||
]
|
||||
|
||||
@@ -102,6 +108,82 @@ Middleware can return either:
|
||||
"""
|
||||
|
||||
|
||||
ModelCallHandler = Callable[[ModelRequest], ModelResponse]
|
||||
"""Type alias for the handler callback passed to wrap_model_call hooks.
|
||||
|
||||
The handler executes the model request and returns a ModelResponse. It can be called
|
||||
multiple times for retry logic or skipped entirely to short-circuit execution.
|
||||
|
||||
Examples:
|
||||
Simple passthrough:
|
||||
```python
|
||||
def my_wrapper(request: ModelRequest, handler: ModelCallHandler) -> ModelCallResult:
|
||||
return handler(request)
|
||||
```
|
||||
|
||||
Retry logic:
|
||||
```python
|
||||
def retry_wrapper(request: ModelRequest, handler: ModelCallHandler) -> ModelCallResult:
|
||||
for attempt in range(3):
|
||||
try:
|
||||
return handler(request)
|
||||
except Exception:
|
||||
if attempt == 2:
|
||||
raise
|
||||
```
|
||||
"""
|
||||
|
||||
AsyncModelCallHandler = Callable[[ModelRequest], Awaitable[ModelResponse]]
|
||||
"""Type alias for the async handler callback passed to wrap_model_call hooks.
|
||||
|
||||
The async handler executes the model request and returns a ModelResponse. It can be
|
||||
called multiple times for retry logic or skipped entirely to short-circuit execution.
|
||||
"""
|
||||
|
||||
ModelCallWrapper = Callable[[ModelRequest, ModelCallHandler], ModelCallResult]
|
||||
"""Type alias for synchronous model call wrapper functions.
|
||||
|
||||
A wrapper receives a ModelRequest and a handler callback. It can modify the request,
|
||||
call the handler (potentially multiple times), modify the response, or short-circuit
|
||||
entirely.
|
||||
|
||||
Args:
|
||||
request: Model request containing state, runtime, messages, tools, etc.
|
||||
handler: Callback to execute the model. Can be called multiple times.
|
||||
|
||||
Returns:
|
||||
ModelCallResult (either ModelResponse or AIMessage)
|
||||
|
||||
Examples:
|
||||
Basic retry pattern:
|
||||
```python
|
||||
def retry_on_error(request: ModelRequest, handler: ModelCallHandler) -> ModelCallResult:
|
||||
for attempt in range(3):
|
||||
try:
|
||||
return handler(request)
|
||||
except Exception:
|
||||
if attempt == 2:
|
||||
raise
|
||||
```
|
||||
|
||||
Access runtime context:
|
||||
```python
|
||||
def use_runtime(request: ModelRequest, handler: ModelCallHandler) -> ModelCallResult:
|
||||
user_id = request.runtime.context.get("user_id")
|
||||
# Modify request based on context
|
||||
return handler(request)
|
||||
```
|
||||
"""
|
||||
|
||||
AsyncModelCallWrapper = Callable[[ModelRequest, AsyncModelCallHandler], Awaitable[ModelCallResult]]
|
||||
"""Type alias for asynchronous model call wrapper functions.
|
||||
|
||||
A wrapper receives a ModelRequest and an async handler callback. It can modify the
|
||||
request, call the handler (potentially multiple times), modify the response, or
|
||||
short-circuit entirely.
|
||||
"""
|
||||
|
||||
|
||||
@dataclass
|
||||
class OmitFromSchema:
|
||||
"""Annotation used to mark state attributes as omitted from input or output schemas."""
|
||||
@@ -195,7 +277,7 @@ class AgentMiddleware(Generic[StateT, ContextT]):
|
||||
def wrap_model_call(
|
||||
self,
|
||||
request: ModelRequest,
|
||||
handler: Callable[[ModelRequest], ModelResponse],
|
||||
handler: ModelCallHandler,
|
||||
) -> ModelCallResult:
|
||||
"""Intercept and control model execution via handler callback.
|
||||
|
||||
@@ -278,7 +360,7 @@ class AgentMiddleware(Generic[StateT, ContextT]):
|
||||
async def awrap_model_call(
|
||||
self,
|
||||
request: ModelRequest,
|
||||
handler: Callable[[ModelRequest], Awaitable[ModelResponse]],
|
||||
handler: AsyncModelCallHandler,
|
||||
) -> ModelCallResult:
|
||||
"""Intercept and control async model execution via handler callback.
|
||||
|
||||
@@ -331,7 +413,7 @@ class AgentMiddleware(Generic[StateT, ContextT]):
|
||||
def wrap_tool_call(
|
||||
self,
|
||||
request: ToolCallRequest,
|
||||
handler: Callable[[ToolCallRequest], ToolMessage | Command],
|
||||
handler: ToolCallHandler,
|
||||
) -> ToolMessage | Command:
|
||||
"""Intercept tool execution for retries, monitoring, or modification.
|
||||
|
||||
@@ -395,7 +477,7 @@ class AgentMiddleware(Generic[StateT, ContextT]):
|
||||
async def awrap_tool_call(
|
||||
self,
|
||||
request: ToolCallRequest,
|
||||
handler: Callable[[ToolCallRequest], Awaitable[ToolMessage | Command]],
|
||||
handler: AsyncToolCallHandler,
|
||||
) -> ToolMessage | Command:
|
||||
"""Intercept and control async tool execution via handler callback.
|
||||
|
||||
@@ -480,7 +562,7 @@ class _CallableReturningModelResponse(Protocol[StateT_contra, ContextT]): # typ
|
||||
def __call__(
|
||||
self,
|
||||
request: ModelRequest,
|
||||
handler: Callable[[ModelRequest], ModelResponse],
|
||||
handler: ModelCallHandler,
|
||||
) -> ModelCallResult:
|
||||
"""Intercept model execution via handler callback."""
|
||||
...
|
||||
@@ -495,7 +577,7 @@ class _CallableReturningToolResponse(Protocol):
|
||||
def __call__(
|
||||
self,
|
||||
request: ToolCallRequest,
|
||||
handler: Callable[[ToolCallRequest], ToolMessage | Command],
|
||||
handler: ToolCallHandler,
|
||||
) -> ToolMessage | Command:
|
||||
"""Intercept tool execution via handler callback."""
|
||||
...
|
||||
@@ -1174,7 +1256,7 @@ def dynamic_prompt(
|
||||
async def async_wrapped(
|
||||
self: AgentMiddleware[StateT, ContextT], # noqa: ARG001
|
||||
request: ModelRequest,
|
||||
handler: Callable[[ModelRequest], Awaitable[ModelResponse]],
|
||||
handler: AsyncModelCallHandler,
|
||||
) -> ModelCallResult:
|
||||
prompt = await func(request) # type: ignore[misc]
|
||||
request.system_prompt = prompt
|
||||
@@ -1195,7 +1277,7 @@ def dynamic_prompt(
|
||||
def wrapped(
|
||||
self: AgentMiddleware[StateT, ContextT], # noqa: ARG001
|
||||
request: ModelRequest,
|
||||
handler: Callable[[ModelRequest], ModelResponse],
|
||||
handler: ModelCallHandler,
|
||||
) -> ModelCallResult:
|
||||
prompt = cast("str", func(request))
|
||||
request.system_prompt = prompt
|
||||
@@ -1204,7 +1286,7 @@ def dynamic_prompt(
|
||||
async def async_wrapped_from_sync(
|
||||
self: AgentMiddleware[StateT, ContextT], # noqa: ARG001
|
||||
request: ModelRequest,
|
||||
handler: Callable[[ModelRequest], Awaitable[ModelResponse]],
|
||||
handler: AsyncModelCallHandler,
|
||||
) -> ModelCallResult:
|
||||
# Delegate to sync function
|
||||
prompt = cast("str", func(request))
|
||||
@@ -1337,7 +1419,7 @@ def wrap_model_call(
|
||||
async def async_wrapped(
|
||||
self: AgentMiddleware[StateT, ContextT], # noqa: ARG001
|
||||
request: ModelRequest,
|
||||
handler: Callable[[ModelRequest], Awaitable[ModelResponse]],
|
||||
handler: AsyncModelCallHandler,
|
||||
) -> ModelCallResult:
|
||||
return await func(request, handler) # type: ignore[misc, arg-type]
|
||||
|
||||
@@ -1358,7 +1440,7 @@ def wrap_model_call(
|
||||
def wrapped(
|
||||
self: AgentMiddleware[StateT, ContextT], # noqa: ARG001
|
||||
request: ModelRequest,
|
||||
handler: Callable[[ModelRequest], ModelResponse],
|
||||
handler: ModelCallHandler,
|
||||
) -> ModelCallResult:
|
||||
return func(request, handler)
|
||||
|
||||
@@ -1480,7 +1562,7 @@ def wrap_tool_call(
|
||||
async def async_wrapped(
|
||||
self: AgentMiddleware, # noqa: ARG001
|
||||
request: ToolCallRequest,
|
||||
handler: Callable[[ToolCallRequest], Awaitable[ToolMessage | Command]],
|
||||
handler: AsyncToolCallHandler,
|
||||
) -> ToolMessage | Command:
|
||||
return await func(request, handler) # type: ignore[arg-type,misc]
|
||||
|
||||
@@ -1501,7 +1583,7 @@ def wrap_tool_call(
|
||||
def wrapped(
|
||||
self: AgentMiddleware, # noqa: ARG001
|
||||
request: ToolCallRequest,
|
||||
handler: Callable[[ToolCallRequest], ToolMessage | Command],
|
||||
handler: ToolCallHandler,
|
||||
) -> ToolMessage | Command:
|
||||
return func(request, handler)
|
||||
|
||||
|
||||
@@ -8,14 +8,28 @@ from langchain_core.tools import (
|
||||
tool,
|
||||
)
|
||||
|
||||
from langchain.tools.tool_node import InjectedState, InjectedStore, ToolInvocationError
|
||||
from langchain.tools.tool_node import (
|
||||
AsyncToolCallHandler,
|
||||
AsyncToolCallWrapper,
|
||||
InjectedState,
|
||||
InjectedStore,
|
||||
ToolCallHandler,
|
||||
ToolCallRequest,
|
||||
ToolCallWrapper,
|
||||
ToolInvocationError,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"AsyncToolCallHandler",
|
||||
"AsyncToolCallWrapper",
|
||||
"BaseTool",
|
||||
"InjectedState",
|
||||
"InjectedStore",
|
||||
"InjectedToolArg",
|
||||
"InjectedToolCallId",
|
||||
"ToolCallHandler",
|
||||
"ToolCallRequest",
|
||||
"ToolCallWrapper",
|
||||
"ToolException",
|
||||
"ToolInvocationError",
|
||||
"tool",
|
||||
|
||||
@@ -118,26 +118,55 @@ class ToolCallRequest:
|
||||
tool_call: ToolCall
|
||||
tool: BaseTool
|
||||
state: Any
|
||||
runtime: Any
|
||||
runtime: Any # Runtime[Any] | None, but using Any for simplicity and to avoid circular imports
|
||||
|
||||
|
||||
ToolCallWrapper = Callable[
|
||||
[ToolCallRequest, Callable[[ToolCallRequest], ToolMessage | Command]],
|
||||
ToolMessage | Command,
|
||||
]
|
||||
"""Wrapper for tool call execution with multi-call support.
|
||||
ToolCallHandler = Callable[[ToolCallRequest], ToolMessage | Command]
|
||||
"""Type alias for the handler callback passed to wrap_tool_call hooks.
|
||||
|
||||
Wrapper receives:
|
||||
request: ToolCallRequest with tool_call, tool, state, and runtime.
|
||||
execute: Callable to execute the tool (CAN BE CALLED MULTIPLE TIMES).
|
||||
The handler executes the tool call and returns a ToolMessage or Command. It can be
|
||||
called multiple times for retry logic or skipped entirely to short-circuit execution.
|
||||
|
||||
Examples:
|
||||
Simple passthrough:
|
||||
```python
|
||||
def my_wrapper(request: ToolCallRequest, handler: ToolCallHandler) -> ToolMessage | Command:
|
||||
return handler(request)
|
||||
```
|
||||
|
||||
Retry logic:
|
||||
```python
|
||||
def retry_wrapper(request: ToolCallRequest, handler: ToolCallHandler) -> ToolMessage | Command:
|
||||
for attempt in range(3):
|
||||
try:
|
||||
return handler(request)
|
||||
except Exception:
|
||||
if attempt == 2:
|
||||
raise
|
||||
```
|
||||
"""
|
||||
|
||||
AsyncToolCallHandler = Callable[[ToolCallRequest], Awaitable[ToolMessage | Command]]
|
||||
"""Type alias for the async handler callback passed to wrap_tool_call hooks.
|
||||
|
||||
The async handler executes the tool call and returns a ToolMessage or Command. It can
|
||||
be called multiple times for retry logic or skipped entirely to short-circuit execution.
|
||||
"""
|
||||
|
||||
ToolCallWrapper = Callable[[ToolCallRequest, ToolCallHandler], ToolMessage | Command]
|
||||
"""Type alias for synchronous tool call wrapper functions.
|
||||
|
||||
A wrapper receives a ToolCallRequest and a handler callback. It can modify the request,
|
||||
call the handler (potentially multiple times), modify the response, or short-circuit
|
||||
entirely.
|
||||
|
||||
Args:
|
||||
request: Tool call request with tool_call, tool, state, and runtime.
|
||||
handler: Callback to execute the tool. Can be called multiple times.
|
||||
|
||||
Returns:
|
||||
ToolMessage or Command (the final result).
|
||||
|
||||
The execute callable can be invoked multiple times for retry logic,
|
||||
with potentially modified requests each time. Each call to execute
|
||||
is independent and stateless.
|
||||
|
||||
Note:
|
||||
When implementing middleware for `create_agent`, use
|
||||
`AgentMiddleware.wrap_tool_call` which provides properly typed
|
||||
@@ -145,55 +174,61 @@ Note:
|
||||
|
||||
Examples:
|
||||
Passthrough (execute once):
|
||||
|
||||
def handler(request, execute):
|
||||
return execute(request)
|
||||
```python
|
||||
def passthrough(request: ToolCallRequest, handler: ToolCallHandler) -> ToolMessage | Command:
|
||||
return handler(request)
|
||||
```
|
||||
|
||||
Modify request before execution:
|
||||
|
||||
def handler(request, execute):
|
||||
```python
|
||||
def modify_args(request: ToolCallRequest, handler: ToolCallHandler) -> ToolMessage | Command:
|
||||
request.tool_call["args"]["value"] *= 2
|
||||
return execute(request)
|
||||
return handler(request)
|
||||
```
|
||||
|
||||
Retry on error (execute multiple times):
|
||||
|
||||
def handler(request, execute):
|
||||
```python
|
||||
def retry_on_error(request: ToolCallRequest, handler: ToolCallHandler) -> ToolMessage | Command:
|
||||
for attempt in range(3):
|
||||
try:
|
||||
result = execute(request)
|
||||
result = handler(request)
|
||||
if is_valid(result):
|
||||
return result
|
||||
except Exception:
|
||||
if attempt == 2:
|
||||
raise
|
||||
return result
|
||||
```
|
||||
|
||||
Conditional retry based on response:
|
||||
Access runtime context:
|
||||
```python
|
||||
def use_runtime(request: ToolCallRequest, handler: ToolCallHandler) -> ToolMessage | Command:
|
||||
if request.runtime is not None:
|
||||
thread_id = request.runtime.context.get("thread_id")
|
||||
# Use runtime context
|
||||
return handler(request)
|
||||
```
|
||||
|
||||
def handler(request, execute):
|
||||
for attempt in range(3):
|
||||
result = execute(request)
|
||||
if isinstance(result, ToolMessage) and result.status != "error":
|
||||
return result
|
||||
if attempt < 2:
|
||||
continue
|
||||
return result
|
||||
|
||||
Cache/short-circuit without calling execute:
|
||||
|
||||
def handler(request, execute):
|
||||
Cache/short-circuit without calling handler:
|
||||
```python
|
||||
def with_cache(request: ToolCallRequest, handler: ToolCallHandler) -> ToolMessage | Command:
|
||||
if cached := get_cache(request):
|
||||
return ToolMessage(content=cached, tool_call_id=request.tool_call["id"])
|
||||
result = execute(request)
|
||||
result = handler(request)
|
||||
save_cache(request, result)
|
||||
return result
|
||||
```
|
||||
"""
|
||||
|
||||
AsyncToolCallWrapper = Callable[
|
||||
[ToolCallRequest, Callable[[ToolCallRequest], Awaitable[ToolMessage | Command]]],
|
||||
Awaitable[ToolMessage | Command],
|
||||
[ToolCallRequest, AsyncToolCallHandler], Awaitable[ToolMessage | Command]
|
||||
]
|
||||
"""Async wrapper for tool call execution with multi-call support."""
|
||||
"""Type alias for asynchronous tool call wrapper functions.
|
||||
|
||||
A wrapper receives a ToolCallRequest and an async handler callback. It can modify the
|
||||
request, call the handler (potentially multiple times), modify the response, or
|
||||
short-circuit entirely.
|
||||
"""
|
||||
|
||||
|
||||
class ToolCallWithContext(TypedDict):
|
||||
|
||||
Reference in New Issue
Block a user