mirror of
https://github.com/hwchase17/langchain.git
synced 2026-04-20 05:04:50 +00:00
Compare commits
1 Commits
jacob/meta
...
sr/more-ty
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
7023b6f496 |
@@ -6,12 +6,12 @@ from typing import TYPE_CHECKING, Annotated, Any, Literal
|
||||
|
||||
from langchain_core.messages import AIMessage, ToolCall, ToolMessage
|
||||
from langgraph.channels.untracked_value import UntrackedValue
|
||||
from langgraph.typing import ContextT
|
||||
from typing_extensions import NotRequired, override
|
||||
|
||||
from langchain.agents.middleware.types import (
|
||||
AgentMiddleware,
|
||||
AgentState,
|
||||
ContextT,
|
||||
PrivateStateAttr,
|
||||
ResponseT,
|
||||
hook_config,
|
||||
|
||||
@@ -33,9 +33,13 @@ from langchain_core.messages import (
|
||||
from langgraph.channels.ephemeral_value import EphemeralValue
|
||||
from langgraph.graph.message import add_messages
|
||||
from langgraph.prebuilt.tool_node import ToolCallRequest, ToolCallWrapper
|
||||
from langgraph.typing import ContextT
|
||||
from typing_extensions import NotRequired, Required, TypedDict, TypeVar, Unpack
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from langgraph._internal._typing import StateLike
|
||||
|
||||
ContextT = TypeVar("ContextT", bound="StateLike | None", default=Any)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from langchain_core.language_models.chat_models import BaseChatModel
|
||||
from langchain_core.tools import BaseTool
|
||||
@@ -90,7 +94,7 @@ class ModelRequest(Generic[ContextT]):
|
||||
"""Model request information for the agent.
|
||||
|
||||
Type Parameters:
|
||||
ContextT: The type of the runtime context. Defaults to `None` if not specified.
|
||||
ContextT: The type of the runtime context. Defaults to `Any` if not specified.
|
||||
"""
|
||||
|
||||
model: BaseChatModel
|
||||
@@ -385,7 +389,7 @@ class AgentMiddleware(Generic[StateT, ContextT, ResponseT]):
|
||||
|
||||
Type Parameters:
|
||||
StateT: The type of the agent state. Defaults to `AgentState[Any]`.
|
||||
ContextT: The type of the runtime context. Defaults to `None`.
|
||||
ContextT: The type of the runtime context. Defaults to `Any`.
|
||||
ResponseT: The type of the structured response. Defaults to `Any`.
|
||||
"""
|
||||
|
||||
@@ -850,6 +854,38 @@ class _CallableReturningToolResponse(Protocol):
|
||||
...
|
||||
|
||||
|
||||
class _AsyncCallableReturningModelResponse(Protocol[StateT_contra, ContextT, ResponseT]): # type: ignore[misc]
|
||||
"""Async callable for model call interception with async handler callback.
|
||||
|
||||
Receives async handler callback to execute model and returns awaitable
|
||||
`ModelResponse` or `AIMessage`.
|
||||
"""
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
request: ModelRequest[ContextT],
|
||||
handler: Callable[[ModelRequest[ContextT]], Awaitable[ModelResponse[ResponseT]]],
|
||||
) -> Awaitable[ModelResponse[ResponseT] | AIMessage]:
|
||||
"""Intercept async model execution via handler callback."""
|
||||
...
|
||||
|
||||
|
||||
class _AsyncCallableReturningToolResponse(Protocol):
|
||||
"""Async callable for tool call interception with async handler callback.
|
||||
|
||||
Receives async handler callback to execute tool and returns awaitable
|
||||
`ToolMessage` or `Command`.
|
||||
"""
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
request: ToolCallRequest,
|
||||
handler: Callable[[ToolCallRequest], Awaitable[ToolMessage | Command[Any]]],
|
||||
) -> Awaitable[ToolMessage | Command[Any]]:
|
||||
"""Intercept async tool execution via handler callback."""
|
||||
...
|
||||
|
||||
|
||||
CallableT = TypeVar("CallableT", bound=Callable[..., Any])
|
||||
|
||||
|
||||
@@ -1733,6 +1769,12 @@ def dynamic_prompt(
|
||||
return decorator
|
||||
|
||||
|
||||
@overload
|
||||
def wrap_model_call(
|
||||
func: _AsyncCallableReturningModelResponse[StateT, ContextT, ResponseT],
|
||||
) -> AgentMiddleware[StateT, ContextT]: ...
|
||||
|
||||
|
||||
@overload
|
||||
def wrap_model_call(
|
||||
func: _CallableReturningModelResponse[StateT, ContextT, ResponseT],
|
||||
@@ -1747,20 +1789,28 @@ def wrap_model_call(
|
||||
tools: list[BaseTool] | None = None,
|
||||
name: str | None = None,
|
||||
) -> Callable[
|
||||
[_CallableReturningModelResponse[StateT, ContextT, ResponseT]],
|
||||
[
|
||||
_CallableReturningModelResponse[StateT, ContextT, ResponseT]
|
||||
| _AsyncCallableReturningModelResponse[StateT, ContextT, ResponseT]
|
||||
],
|
||||
AgentMiddleware[StateT, ContextT],
|
||||
]: ...
|
||||
|
||||
|
||||
def wrap_model_call(
|
||||
func: _CallableReturningModelResponse[StateT, ContextT, ResponseT] | None = None,
|
||||
func: _CallableReturningModelResponse[StateT, ContextT, ResponseT]
|
||||
| _AsyncCallableReturningModelResponse[StateT, ContextT, ResponseT]
|
||||
| None = None,
|
||||
*,
|
||||
state_schema: type[StateT] | None = None,
|
||||
tools: list[BaseTool] | None = None,
|
||||
name: str | None = None,
|
||||
) -> (
|
||||
Callable[
|
||||
[_CallableReturningModelResponse[StateT, ContextT, ResponseT]],
|
||||
[
|
||||
_CallableReturningModelResponse[StateT, ContextT, ResponseT]
|
||||
| _AsyncCallableReturningModelResponse[StateT, ContextT, ResponseT]
|
||||
],
|
||||
AgentMiddleware[StateT, ContextT],
|
||||
]
|
||||
| AgentMiddleware[StateT, ContextT]
|
||||
@@ -1841,18 +1891,24 @@ def wrap_model_call(
|
||||
"""
|
||||
|
||||
def decorator(
|
||||
func: _CallableReturningModelResponse[StateT, ContextT, ResponseT],
|
||||
func: _CallableReturningModelResponse[StateT, ContextT, ResponseT]
|
||||
| _AsyncCallableReturningModelResponse[StateT, ContextT, ResponseT],
|
||||
) -> AgentMiddleware[StateT, ContextT]:
|
||||
is_async = iscoroutinefunction(func)
|
||||
|
||||
if is_async:
|
||||
# iscoroutinefunction is not a TypeGuard, so narrow manually
|
||||
_async_func = cast(
|
||||
"_AsyncCallableReturningModelResponse[StateT, ContextT, ResponseT]",
|
||||
func,
|
||||
)
|
||||
|
||||
async def async_wrapped(
|
||||
_self: AgentMiddleware[StateT, ContextT],
|
||||
request: ModelRequest[ContextT],
|
||||
handler: Callable[[ModelRequest[ContextT]], Awaitable[ModelResponse[ResponseT]]],
|
||||
) -> ModelResponse[ResponseT] | AIMessage:
|
||||
return await func(request, handler) # type: ignore[misc, arg-type]
|
||||
return await _async_func(request, handler)
|
||||
|
||||
middleware_name = name or cast(
|
||||
"str", getattr(func, "__name__", "WrapModelCallMiddleware")
|
||||
@@ -1868,12 +1924,14 @@ def wrap_model_call(
|
||||
},
|
||||
)()
|
||||
|
||||
_sync_func = cast("_CallableReturningModelResponse[StateT, ContextT, ResponseT]", func)
|
||||
|
||||
def wrapped(
|
||||
_self: AgentMiddleware[StateT, ContextT],
|
||||
request: ModelRequest[ContextT],
|
||||
handler: Callable[[ModelRequest[ContextT]], ModelResponse[ResponseT]],
|
||||
) -> ModelResponse[ResponseT] | AIMessage:
|
||||
return func(request, handler)
|
||||
return _sync_func(request, handler)
|
||||
|
||||
middleware_name = name or cast("str", getattr(func, "__name__", "WrapModelCallMiddleware"))
|
||||
|
||||
@@ -1892,6 +1950,12 @@ def wrap_model_call(
|
||||
return decorator
|
||||
|
||||
|
||||
@overload
|
||||
def wrap_tool_call(
|
||||
func: _AsyncCallableReturningToolResponse,
|
||||
) -> AgentMiddleware: ...
|
||||
|
||||
|
||||
@overload
|
||||
def wrap_tool_call(
|
||||
func: _CallableReturningToolResponse,
|
||||
@@ -1905,19 +1969,19 @@ def wrap_tool_call(
|
||||
tools: list[BaseTool] | None = None,
|
||||
name: str | None = None,
|
||||
) -> Callable[
|
||||
[_CallableReturningToolResponse],
|
||||
[_CallableReturningToolResponse | _AsyncCallableReturningToolResponse],
|
||||
AgentMiddleware,
|
||||
]: ...
|
||||
|
||||
|
||||
def wrap_tool_call(
|
||||
func: _CallableReturningToolResponse | None = None,
|
||||
func: _CallableReturningToolResponse | _AsyncCallableReturningToolResponse | None = None,
|
||||
*,
|
||||
tools: list[BaseTool] | None = None,
|
||||
name: str | None = None,
|
||||
) -> (
|
||||
Callable[
|
||||
[_CallableReturningToolResponse],
|
||||
[_CallableReturningToolResponse | _AsyncCallableReturningToolResponse],
|
||||
AgentMiddleware,
|
||||
]
|
||||
| AgentMiddleware
|
||||
@@ -2001,18 +2065,20 @@ def wrap_tool_call(
|
||||
"""
|
||||
|
||||
def decorator(
|
||||
func: _CallableReturningToolResponse,
|
||||
func: _CallableReturningToolResponse | _AsyncCallableReturningToolResponse,
|
||||
) -> AgentMiddleware:
|
||||
is_async = iscoroutinefunction(func)
|
||||
|
||||
if is_async:
|
||||
# iscoroutinefunction is not a TypeGuard, so narrow manually
|
||||
_async_func = cast("_AsyncCallableReturningToolResponse", func)
|
||||
|
||||
async def async_wrapped(
|
||||
_self: AgentMiddleware,
|
||||
request: ToolCallRequest,
|
||||
handler: Callable[[ToolCallRequest], Awaitable[ToolMessage | Command[Any]]],
|
||||
) -> ToolMessage | Command[Any]:
|
||||
return await func(request, handler) # type: ignore[arg-type,misc]
|
||||
return await _async_func(request, handler)
|
||||
|
||||
middleware_name = name or cast(
|
||||
"str", getattr(func, "__name__", "WrapToolCallMiddleware")
|
||||
@@ -2028,12 +2094,14 @@ def wrap_tool_call(
|
||||
},
|
||||
)()
|
||||
|
||||
_sync_func = cast("_CallableReturningToolResponse", func)
|
||||
|
||||
def wrapped(
|
||||
_self: AgentMiddleware,
|
||||
request: ToolCallRequest,
|
||||
handler: Callable[[ToolCallRequest], ToolMessage | Command[Any]],
|
||||
) -> ToolMessage | Command[Any]:
|
||||
return func(request, handler)
|
||||
return _sync_func(request, handler)
|
||||
|
||||
middleware_name = name or cast("str", getattr(func, "__name__", "WrapToolCallMiddleware"))
|
||||
|
||||
|
||||
@@ -39,7 +39,7 @@ if TYPE_CHECKING:
|
||||
class OldStyleMiddleware1(AgentMiddleware):
|
||||
"""Middleware with no type parameters at all - most common old pattern."""
|
||||
|
||||
def before_model(self, state: AgentState[Any], runtime: Runtime[None]) -> dict[str, Any] | None:
|
||||
def before_model(self, state: AgentState[Any], runtime: Runtime[Any]) -> dict[str, Any] | None:
|
||||
# Simple middleware that just logs or does something
|
||||
return None
|
||||
|
||||
@@ -104,7 +104,7 @@ class OldStyleMiddleware4(AgentMiddleware[AgentState[Any], MyContext]):
|
||||
# OLD PATTERN 5: Decorator-based middleware
|
||||
# =============================================================================
|
||||
@before_model
|
||||
def old_style_decorator(state: AgentState[Any], runtime: Runtime[None]) -> dict[str, Any] | None:
|
||||
def old_style_decorator(state: AgentState[Any], runtime: Runtime[Any]) -> dict[str, Any] | None:
|
||||
"""Decorator middleware - old pattern."""
|
||||
return None
|
||||
|
||||
|
||||
@@ -7,8 +7,10 @@ Expected errors:
|
||||
1. TypedDict "UserContext" has no key "session_id" - accessing wrong context field
|
||||
2. Argument incompatible with supertype - mismatched ModelRequest type
|
||||
3. Cannot infer value of type parameter - middleware/context_schema mismatch
|
||||
4. "AnalysisResult" has no attribute "summary" - accessing wrong response field
|
||||
5. Handler returns wrong ResponseT type
|
||||
4. (No longer an error) Backwards compatible middleware with context_schema is OK
|
||||
because ContextT defaults to Any, which is compatible with any context_schema
|
||||
5. "AnalysisResult" has no attribute "summary" - accessing wrong response field
|
||||
6. Handler returns wrong ResponseT type
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
@@ -107,7 +109,9 @@ def test_mismatched_context_schema() -> None:
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# ERROR 4: Backwards compatible middleware with typed context_schema
|
||||
# NOTE 4: Backwards compatible middleware with typed context_schema
|
||||
# This is NOT a type error because ContextT defaults to Any, which is
|
||||
# compatible with any context_schema.
|
||||
# =============================================================================
|
||||
class BackwardsCompatibleMiddleware(AgentMiddleware):
|
||||
def wrap_model_call(
|
||||
@@ -119,10 +123,10 @@ class BackwardsCompatibleMiddleware(AgentMiddleware):
|
||||
|
||||
|
||||
def test_backwards_compat_with_context_schema() -> None:
|
||||
# TYPE ERROR: BackwardsCompatibleMiddleware is AgentMiddleware[..., None]
|
||||
# but context_schema=UserContext expects AgentMiddleware[..., UserContext]
|
||||
# BackwardsCompatibleMiddleware uses default type params (Any),
|
||||
# so it is compatible with any context_schema including UserContext
|
||||
fake_model = FakeToolCallingModel()
|
||||
_agent = create_agent( # type: ignore[misc]
|
||||
_agent = create_agent(
|
||||
model=fake_model,
|
||||
middleware=[BackwardsCompatibleMiddleware()],
|
||||
context_schema=UserContext,
|
||||
|
||||
@@ -18,7 +18,7 @@ from typing import TYPE_CHECKING, Any
|
||||
|
||||
import pytest
|
||||
from langchain_core.language_models.fake_chat_models import GenericFakeChatModel
|
||||
from langchain_core.messages import AIMessage, HumanMessage
|
||||
from langchain_core.messages import AIMessage, HumanMessage, ToolMessage
|
||||
from pydantic import BaseModel
|
||||
from typing_extensions import TypedDict
|
||||
|
||||
@@ -31,13 +31,17 @@ from langchain.agents.middleware.types import (
|
||||
ModelResponse,
|
||||
ResponseT,
|
||||
before_model,
|
||||
wrap_model_call,
|
||||
wrap_tool_call,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from collections.abc import Awaitable, Callable
|
||||
|
||||
from langgraph.graph.state import CompiledStateGraph
|
||||
from langgraph.prebuilt.tool_node import ToolCallRequest
|
||||
from langgraph.runtime import Runtime
|
||||
from langgraph.types import Command
|
||||
|
||||
|
||||
# =============================================================================
|
||||
@@ -78,7 +82,7 @@ class SummaryResult(BaseModel):
|
||||
class BackwardsCompatibleMiddleware(AgentMiddleware):
|
||||
"""Middleware that doesn't specify type parameters - backwards compatible."""
|
||||
|
||||
def before_model(self, state: AgentState[Any], runtime: Runtime[None]) -> dict[str, Any] | None:
|
||||
def before_model(self, state: AgentState[Any], runtime: Runtime[Any]) -> dict[str, Any] | None:
|
||||
return None
|
||||
|
||||
def wrap_model_call(
|
||||
@@ -103,7 +107,7 @@ class BackwardsCompatibleMiddleware2(AgentMiddleware):
|
||||
|
||||
@before_model
|
||||
def backwards_compatible_decorator(
|
||||
state: AgentState[Any], runtime: Runtime[None]
|
||||
state: AgentState[Any], runtime: Runtime[Any]
|
||||
) -> dict[str, Any] | None:
|
||||
"""Decorator middleware without explicit type parameters."""
|
||||
return None
|
||||
@@ -237,7 +241,7 @@ def fake_model() -> GenericFakeChatModel:
|
||||
|
||||
|
||||
def test_create_agent_no_context_schema(fake_model: GenericFakeChatModel) -> None:
|
||||
"""Backwards compatible: No context_schema means ContextT=None."""
|
||||
"""Backwards compatible: No context_schema means ContextT=Any."""
|
||||
agent: CompiledStateGraph[Any, None, Any, Any] = create_agent(
|
||||
model=fake_model,
|
||||
middleware=[
|
||||
@@ -441,3 +445,148 @@ def test_model_response_backwards_compatible() -> None:
|
||||
)
|
||||
|
||||
assert response.structured_response is None
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# 9. ASYNC DECORATOR VARIANTS FOR wrap_model_call AND wrap_tool_call
|
||||
# =============================================================================
|
||||
@wrap_model_call
|
||||
async def async_wrap_model_retry(
|
||||
request: ModelRequest[UserContext],
|
||||
handler: Callable[[ModelRequest[UserContext]], Awaitable[ModelResponse[Any]]],
|
||||
) -> ModelResponse[Any] | AIMessage:
|
||||
"""Async wrap_model_call decorator should type-check correctly."""
|
||||
for attempt in range(3):
|
||||
try:
|
||||
return await handler(request)
|
||||
except Exception:
|
||||
if attempt == 2:
|
||||
raise
|
||||
return await handler(request)
|
||||
|
||||
|
||||
@wrap_model_call
|
||||
async def async_wrap_model_unparameterized(
|
||||
request: ModelRequest,
|
||||
handler: Callable[[ModelRequest], Awaitable[ModelResponse]],
|
||||
) -> ModelResponse | AIMessage:
|
||||
"""Async wrap_model_call with unparameterized types (backwards compat)."""
|
||||
return await handler(request)
|
||||
|
||||
|
||||
@wrap_tool_call
|
||||
async def async_wrap_tool_retry(
|
||||
request: ToolCallRequest,
|
||||
handler: Callable[[ToolCallRequest], Awaitable[ToolMessage | Command[Any]]],
|
||||
) -> ToolMessage | Command[Any]:
|
||||
"""Async wrap_tool_call decorator should type-check correctly."""
|
||||
for attempt in range(3):
|
||||
try:
|
||||
return await handler(request)
|
||||
except Exception:
|
||||
if attempt == 2:
|
||||
raise
|
||||
return await handler(request)
|
||||
|
||||
|
||||
def test_async_wrap_model_call_decorator(fake_model: GenericFakeChatModel) -> None:
|
||||
"""Async @wrap_model_call decorator produces valid AgentMiddleware."""
|
||||
agent = create_agent(
|
||||
model=fake_model,
|
||||
middleware=[async_wrap_model_retry],
|
||||
context_schema=UserContext,
|
||||
)
|
||||
assert agent is not None
|
||||
|
||||
|
||||
def test_async_wrap_model_call_unparameterized(fake_model: GenericFakeChatModel) -> None:
|
||||
"""Async @wrap_model_call with unparameterized types works."""
|
||||
agent = create_agent(
|
||||
model=fake_model,
|
||||
middleware=[async_wrap_model_unparameterized],
|
||||
)
|
||||
assert agent is not None
|
||||
|
||||
|
||||
def test_async_wrap_tool_call_decorator(fake_model: GenericFakeChatModel) -> None:
|
||||
"""Async @wrap_tool_call decorator produces valid AgentMiddleware."""
|
||||
agent = create_agent(
|
||||
model=fake_model,
|
||||
middleware=[async_wrap_tool_retry],
|
||||
)
|
||||
assert agent is not None
|
||||
|
||||
|
||||
# Test sync decorators still work (regression check)
|
||||
@wrap_model_call
|
||||
def sync_wrap_model(
|
||||
request: ModelRequest[UserContext],
|
||||
handler: Callable[[ModelRequest[UserContext]], ModelResponse[Any]],
|
||||
) -> ModelResponse[Any] | AIMessage:
|
||||
"""Sync wrap_model_call should still type-check correctly."""
|
||||
return handler(request)
|
||||
|
||||
|
||||
@wrap_tool_call
|
||||
def sync_wrap_tool(
|
||||
request: ToolCallRequest,
|
||||
handler: Callable[[ToolCallRequest], ToolMessage | Command[Any]],
|
||||
) -> ToolMessage | Command[Any]:
|
||||
"""Sync wrap_tool_call should still type-check correctly."""
|
||||
return handler(request)
|
||||
|
||||
|
||||
def test_sync_wrap_model_call_decorator(fake_model: GenericFakeChatModel) -> None:
|
||||
"""Sync @wrap_model_call decorator still works."""
|
||||
agent = create_agent(
|
||||
model=fake_model,
|
||||
middleware=[sync_wrap_model],
|
||||
context_schema=UserContext,
|
||||
)
|
||||
assert agent is not None
|
||||
|
||||
|
||||
def test_sync_wrap_tool_call_decorator(fake_model: GenericFakeChatModel) -> None:
|
||||
"""Sync @wrap_tool_call decorator still works."""
|
||||
agent = create_agent(
|
||||
model=fake_model,
|
||||
middleware=[sync_wrap_tool],
|
||||
)
|
||||
assert agent is not None
|
||||
|
||||
|
||||
# Test with func=None pattern (parenthesized decorator)
|
||||
@wrap_model_call()
|
||||
async def async_wrap_model_with_parens(
|
||||
request: ModelRequest,
|
||||
handler: Callable[[ModelRequest], Awaitable[ModelResponse]],
|
||||
) -> ModelResponse | AIMessage:
|
||||
"""Async wrap_model_call with parentheses should type-check correctly."""
|
||||
return await handler(request)
|
||||
|
||||
|
||||
@wrap_tool_call()
|
||||
async def async_wrap_tool_with_parens(
|
||||
request: ToolCallRequest,
|
||||
handler: Callable[[ToolCallRequest], Awaitable[ToolMessage | Command[Any]]],
|
||||
) -> ToolMessage | Command[Any]:
|
||||
"""Async wrap_tool_call with parentheses should type-check correctly."""
|
||||
return await handler(request)
|
||||
|
||||
|
||||
def test_async_wrap_model_call_with_parens(fake_model: GenericFakeChatModel) -> None:
|
||||
"""Async @wrap_model_call() with parens produces valid AgentMiddleware."""
|
||||
agent = create_agent(
|
||||
model=fake_model,
|
||||
middleware=[async_wrap_model_with_parens],
|
||||
)
|
||||
assert agent is not None
|
||||
|
||||
|
||||
def test_async_wrap_tool_call_with_parens(fake_model: GenericFakeChatModel) -> None:
|
||||
"""Async @wrap_tool_call() with parens produces valid AgentMiddleware."""
|
||||
agent = create_agent(
|
||||
model=fake_model,
|
||||
middleware=[async_wrap_tool_with_parens],
|
||||
)
|
||||
assert agent is not None
|
||||
|
||||
Reference in New Issue
Block a user