mirror of
https://github.com/hwchase17/langchain.git
synced 2026-02-13 22:32:33 +00:00
Compare commits
4 Commits
mdrxy/open
...
sr/multipl
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
74463b299e | ||
|
|
20e5fd4186 | ||
|
|
46ad97c297 | ||
|
|
73d9061764 |
@@ -86,13 +86,19 @@ def _normalize_to_model_response(result: ModelResponse | AIMessage) -> ModelResp
|
||||
def _chain_model_call_handlers(
|
||||
handlers: Sequence[
|
||||
Callable[
|
||||
[ModelRequest, Callable[[ModelRequest], ModelResponse]],
|
||||
[
|
||||
ModelRequest[Any, ContextT],
|
||||
Callable[[ModelRequest[Any, ContextT]], ModelResponse],
|
||||
],
|
||||
ModelResponse | AIMessage,
|
||||
]
|
||||
],
|
||||
) -> (
|
||||
Callable[
|
||||
[ModelRequest, Callable[[ModelRequest], ModelResponse]],
|
||||
[
|
||||
ModelRequest[Any, ContextT],
|
||||
Callable[[ModelRequest[Any, ContextT]], ModelResponse],
|
||||
],
|
||||
ModelResponse,
|
||||
]
|
||||
| None
|
||||
@@ -140,8 +146,8 @@ def _chain_model_call_handlers(
|
||||
single_handler = handlers[0]
|
||||
|
||||
def normalized_single(
|
||||
request: ModelRequest,
|
||||
handler: Callable[[ModelRequest], ModelResponse],
|
||||
request: ModelRequest[Any, ContextT],
|
||||
handler: Callable[[ModelRequest[Any, ContextT]], ModelResponse],
|
||||
) -> ModelResponse:
|
||||
result = single_handler(request, handler)
|
||||
return _normalize_to_model_response(result)
|
||||
@@ -150,25 +156,34 @@ def _chain_model_call_handlers(
|
||||
|
||||
def compose_two(
|
||||
outer: Callable[
|
||||
[ModelRequest, Callable[[ModelRequest], ModelResponse]],
|
||||
[
|
||||
ModelRequest[Any, ContextT],
|
||||
Callable[[ModelRequest[Any, ContextT]], ModelResponse],
|
||||
],
|
||||
ModelResponse | AIMessage,
|
||||
],
|
||||
inner: Callable[
|
||||
[ModelRequest, Callable[[ModelRequest], ModelResponse]],
|
||||
[
|
||||
ModelRequest[Any, ContextT],
|
||||
Callable[[ModelRequest[Any, ContextT]], ModelResponse],
|
||||
],
|
||||
ModelResponse | AIMessage,
|
||||
],
|
||||
) -> Callable[
|
||||
[ModelRequest, Callable[[ModelRequest], ModelResponse]],
|
||||
[
|
||||
ModelRequest[Any, ContextT],
|
||||
Callable[[ModelRequest[Any, ContextT]], ModelResponse],
|
||||
],
|
||||
ModelResponse,
|
||||
]:
|
||||
"""Compose two handlers where outer wraps inner."""
|
||||
|
||||
def composed(
|
||||
request: ModelRequest,
|
||||
handler: Callable[[ModelRequest], ModelResponse],
|
||||
request: ModelRequest[Any, ContextT],
|
||||
handler: Callable[[ModelRequest[Any, ContextT]], ModelResponse],
|
||||
) -> ModelResponse:
|
||||
# Create a wrapper that calls inner with the base handler and normalizes
|
||||
def inner_handler(req: ModelRequest) -> ModelResponse:
|
||||
def inner_handler(req: ModelRequest[Any, ContextT]) -> ModelResponse:
|
||||
inner_result = inner(req, handler)
|
||||
return _normalize_to_model_response(inner_result)
|
||||
|
||||
@@ -185,8 +200,8 @@ def _chain_model_call_handlers(
|
||||
|
||||
# Wrap to ensure final return type is exactly ModelResponse
|
||||
def final_normalized(
|
||||
request: ModelRequest,
|
||||
handler: Callable[[ModelRequest], ModelResponse],
|
||||
request: ModelRequest[Any, ContextT],
|
||||
handler: Callable[[ModelRequest[Any, ContextT]], ModelResponse],
|
||||
) -> ModelResponse:
|
||||
# result here is typed as returning ModelResponse | AIMessage but compose_two normalizes
|
||||
final_result = result(request, handler)
|
||||
@@ -198,13 +213,19 @@ def _chain_model_call_handlers(
|
||||
def _chain_async_model_call_handlers(
|
||||
handlers: Sequence[
|
||||
Callable[
|
||||
[ModelRequest, Callable[[ModelRequest], Awaitable[ModelResponse]]],
|
||||
[
|
||||
ModelRequest[Any, ContextT],
|
||||
Callable[[ModelRequest[Any, ContextT]], Awaitable[ModelResponse]],
|
||||
],
|
||||
Awaitable[ModelResponse | AIMessage],
|
||||
]
|
||||
],
|
||||
) -> (
|
||||
Callable[
|
||||
[ModelRequest, Callable[[ModelRequest], Awaitable[ModelResponse]]],
|
||||
[
|
||||
ModelRequest[Any, ContextT],
|
||||
Callable[[ModelRequest[Any, ContextT]], Awaitable[ModelResponse]],
|
||||
],
|
||||
Awaitable[ModelResponse],
|
||||
]
|
||||
| None
|
||||
@@ -225,8 +246,8 @@ def _chain_async_model_call_handlers(
|
||||
single_handler = handlers[0]
|
||||
|
||||
async def normalized_single(
|
||||
request: ModelRequest,
|
||||
handler: Callable[[ModelRequest], Awaitable[ModelResponse]],
|
||||
request: ModelRequest[Any, ContextT],
|
||||
handler: Callable[[ModelRequest[Any, ContextT]], Awaitable[ModelResponse]],
|
||||
) -> ModelResponse:
|
||||
result = await single_handler(request, handler)
|
||||
return _normalize_to_model_response(result)
|
||||
@@ -235,25 +256,34 @@ def _chain_async_model_call_handlers(
|
||||
|
||||
def compose_two(
|
||||
outer: Callable[
|
||||
[ModelRequest, Callable[[ModelRequest], Awaitable[ModelResponse]]],
|
||||
[
|
||||
ModelRequest[Any, ContextT],
|
||||
Callable[[ModelRequest[Any, ContextT]], Awaitable[ModelResponse]],
|
||||
],
|
||||
Awaitable[ModelResponse | AIMessage],
|
||||
],
|
||||
inner: Callable[
|
||||
[ModelRequest, Callable[[ModelRequest], Awaitable[ModelResponse]]],
|
||||
[
|
||||
ModelRequest[Any, ContextT],
|
||||
Callable[[ModelRequest[Any, ContextT]], Awaitable[ModelResponse]],
|
||||
],
|
||||
Awaitable[ModelResponse | AIMessage],
|
||||
],
|
||||
) -> Callable[
|
||||
[ModelRequest, Callable[[ModelRequest], Awaitable[ModelResponse]]],
|
||||
[
|
||||
ModelRequest[Any, ContextT],
|
||||
Callable[[ModelRequest[Any, ContextT]], Awaitable[ModelResponse]],
|
||||
],
|
||||
Awaitable[ModelResponse],
|
||||
]:
|
||||
"""Compose two async handlers where outer wraps inner."""
|
||||
|
||||
async def composed(
|
||||
request: ModelRequest,
|
||||
handler: Callable[[ModelRequest], Awaitable[ModelResponse]],
|
||||
request: ModelRequest[Any, ContextT],
|
||||
handler: Callable[[ModelRequest[Any, ContextT]], Awaitable[ModelResponse]],
|
||||
) -> ModelResponse:
|
||||
# Create a wrapper that calls inner with the base handler and normalizes
|
||||
async def inner_handler(req: ModelRequest) -> ModelResponse:
|
||||
async def inner_handler(req: ModelRequest[Any, ContextT]) -> ModelResponse:
|
||||
inner_result = await inner(req, handler)
|
||||
return _normalize_to_model_response(inner_result)
|
||||
|
||||
@@ -270,8 +300,8 @@ def _chain_async_model_call_handlers(
|
||||
|
||||
# Wrap to ensure final return type is exactly ModelResponse
|
||||
async def final_normalized(
|
||||
request: ModelRequest,
|
||||
handler: Callable[[ModelRequest], Awaitable[ModelResponse]],
|
||||
request: ModelRequest[Any, ContextT],
|
||||
handler: Callable[[ModelRequest[Any, ContextT]], Awaitable[ModelResponse]],
|
||||
) -> ModelResponse:
|
||||
# result here is typed as returning ModelResponse | AIMessage but compose_two normalizes
|
||||
final_result = await result(request, handler)
|
||||
@@ -973,7 +1003,9 @@ def create_agent(
|
||||
|
||||
return {"messages": [output]}
|
||||
|
||||
def _get_bound_model(request: ModelRequest) -> tuple[Runnable, ResponseFormat | None]:
|
||||
def _get_bound_model(
|
||||
request: ModelRequest[Any, ContextT],
|
||||
) -> tuple[Runnable, ResponseFormat | None]:
|
||||
"""Get the model with appropriate tool bindings.
|
||||
|
||||
Performs auto-detection of strategy if needed based on model capabilities.
|
||||
@@ -1087,7 +1119,7 @@ def create_agent(
|
||||
)
|
||||
return request.model.bind(**request.model_settings), None
|
||||
|
||||
def _execute_model_sync(request: ModelRequest) -> ModelResponse:
|
||||
def _execute_model_sync(request: ModelRequest[Any, ContextT]) -> ModelResponse:
|
||||
"""Execute model and return response.
|
||||
|
||||
This is the core model execution logic wrapped by `wrap_model_call` handlers.
|
||||
@@ -1140,7 +1172,7 @@ def create_agent(
|
||||
|
||||
return state_updates
|
||||
|
||||
async def _execute_model_async(request: ModelRequest) -> ModelResponse:
|
||||
async def _execute_model_async(request: ModelRequest[Any, ContextT]) -> ModelResponse:
|
||||
"""Execute model asynchronously and return response.
|
||||
|
||||
This is the core async model execution logic wrapped by `wrap_model_call`
|
||||
|
||||
@@ -67,7 +67,9 @@ __all__ = [
|
||||
JumpTo = Literal["tools", "model", "end"]
|
||||
"""Destination to jump to when a middleware node returns."""
|
||||
|
||||
ResponseT = TypeVar("ResponseT")
|
||||
ResponseT = TypeVar("ResponseT", default=Any)
|
||||
# StateT uses string forward references since AgentState is defined later
|
||||
StateT = TypeVar("StateT", bound="AgentState", default="AgentState")
|
||||
|
||||
|
||||
class _ModelRequestOverrides(TypedDict, total=False):
|
||||
@@ -83,7 +85,7 @@ class _ModelRequestOverrides(TypedDict, total=False):
|
||||
|
||||
|
||||
@dataclass(init=False)
|
||||
class ModelRequest:
|
||||
class ModelRequest(Generic[StateT, ContextT]):
|
||||
"""Model request information for the agent."""
|
||||
|
||||
model: BaseChatModel
|
||||
@@ -92,8 +94,8 @@ class ModelRequest:
|
||||
tool_choice: Any | None
|
||||
tools: list[BaseTool | dict]
|
||||
response_format: ResponseFormat | None
|
||||
state: AgentState
|
||||
runtime: Runtime[ContextT] # type: ignore[valid-type]
|
||||
state: StateT
|
||||
runtime: Runtime[ContextT]
|
||||
model_settings: dict[str, Any] = field(default_factory=dict)
|
||||
|
||||
def __init__(
|
||||
@@ -106,7 +108,7 @@ class ModelRequest:
|
||||
tool_choice: Any | None = None,
|
||||
tools: list[BaseTool | dict] | None = None,
|
||||
response_format: ResponseFormat | None = None,
|
||||
state: AgentState | None = None,
|
||||
state: StateT | None = None,
|
||||
runtime: Runtime[ContextT] | None = None,
|
||||
model_settings: dict[str, Any] | None = None,
|
||||
) -> None:
|
||||
@@ -140,7 +142,7 @@ class ModelRequest:
|
||||
self.tool_choice = tool_choice
|
||||
self.tools = tools if tools is not None else []
|
||||
self.response_format = response_format
|
||||
self.state = state if state is not None else {"messages": []}
|
||||
self.state = state if state is not None else cast("StateT", {"messages": []})
|
||||
self.runtime = runtime # type: ignore[assignment]
|
||||
self.model_settings = model_settings if model_settings is not None else {}
|
||||
|
||||
@@ -189,7 +191,9 @@ class ModelRequest:
|
||||
)
|
||||
object.__setattr__(self, name, value)
|
||||
|
||||
def override(self, **overrides: Unpack[_ModelRequestOverrides]) -> ModelRequest:
|
||||
def override(
|
||||
self, **overrides: Unpack[_ModelRequestOverrides]
|
||||
) -> ModelRequest[StateT, ContextT]:
|
||||
"""Replace the request with a new request with the given overrides.
|
||||
|
||||
Returns a new `ModelRequest` instance with the specified attributes replaced.
|
||||
@@ -322,7 +326,6 @@ class _OutputAgentState(TypedDict, Generic[ResponseT]): # noqa: PYI049
|
||||
structured_response: NotRequired[ResponseT]
|
||||
|
||||
|
||||
StateT = TypeVar("StateT", bound=AgentState, default=AgentState)
|
||||
StateT_co = TypeVar("StateT_co", bound=AgentState, default=AgentState, covariant=True)
|
||||
StateT_contra = TypeVar("StateT_contra", bound=AgentState, contravariant=True)
|
||||
|
||||
@@ -383,8 +386,8 @@ class AgentMiddleware(Generic[StateT, ContextT]):
|
||||
|
||||
def wrap_model_call(
|
||||
self,
|
||||
request: ModelRequest,
|
||||
handler: Callable[[ModelRequest], ModelResponse],
|
||||
request: ModelRequest[StateT, ContextT],
|
||||
handler: Callable[[ModelRequest[StateT, ContextT]], ModelResponse],
|
||||
) -> ModelCallResult:
|
||||
"""Intercept and control model execution via handler callback.
|
||||
|
||||
@@ -478,8 +481,8 @@ class AgentMiddleware(Generic[StateT, ContextT]):
|
||||
|
||||
async def awrap_model_call(
|
||||
self,
|
||||
request: ModelRequest,
|
||||
handler: Callable[[ModelRequest], Awaitable[ModelResponse]],
|
||||
request: ModelRequest[StateT, ContextT],
|
||||
handler: Callable[[ModelRequest[StateT, ContextT]], Awaitable[ModelResponse]],
|
||||
) -> ModelCallResult:
|
||||
"""Intercept and control async model execution via handler callback.
|
||||
|
||||
@@ -698,18 +701,24 @@ class _CallableWithStateAndRuntime(Protocol[StateT_contra, ContextT]):
|
||||
...
|
||||
|
||||
|
||||
class _CallableReturningSystemMessage(Protocol[StateT_contra, ContextT]): # type: ignore[misc]
|
||||
"""Callable that returns a prompt string or SystemMessage given `ModelRequest`."""
|
||||
class _SyncCallableReturningSystemMessage(Protocol[StateT, ContextT]):
|
||||
"""Sync callable that returns a prompt string or SystemMessage given `ModelRequest`."""
|
||||
|
||||
def __call__(
|
||||
self, request: ModelRequest
|
||||
) -> str | SystemMessage | Awaitable[str | SystemMessage]:
|
||||
def __call__(self, request: ModelRequest[StateT, ContextT]) -> str | SystemMessage:
|
||||
"""Generate a system prompt string or SystemMessage based on the request."""
|
||||
...
|
||||
|
||||
|
||||
class _CallableReturningModelResponse(Protocol[StateT_contra, ContextT]): # type: ignore[misc]
|
||||
"""Callable for model call interception with handler callback.
|
||||
class _AsyncCallableReturningSystemMessage(Protocol[StateT, ContextT]):
|
||||
"""Async callable that returns a prompt string or SystemMessage given `ModelRequest`."""
|
||||
|
||||
def __call__(self, request: ModelRequest[StateT, ContextT]) -> Awaitable[str | SystemMessage]:
|
||||
"""Generate a system prompt string or SystemMessage based on the request."""
|
||||
...
|
||||
|
||||
|
||||
class _SyncCallableReturningModelResponse(Protocol[StateT, ContextT]):
|
||||
"""Sync callable for model call interception with handler callback.
|
||||
|
||||
Receives handler callback to execute model and returns `ModelResponse` or
|
||||
`AIMessage`.
|
||||
@@ -717,15 +726,31 @@ class _CallableReturningModelResponse(Protocol[StateT_contra, ContextT]): # typ
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
request: ModelRequest,
|
||||
handler: Callable[[ModelRequest], ModelResponse],
|
||||
request: ModelRequest[StateT, ContextT],
|
||||
handler: Callable[[ModelRequest[StateT, ContextT]], ModelResponse],
|
||||
) -> ModelCallResult:
|
||||
"""Intercept model execution via handler callback."""
|
||||
...
|
||||
|
||||
|
||||
class _CallableReturningToolResponse(Protocol):
|
||||
"""Callable for tool call interception with handler callback.
|
||||
class _AsyncCallableReturningModelResponse(Protocol[StateT, ContextT]):
|
||||
"""Async callable for model call interception with handler callback.
|
||||
|
||||
Receives async handler callback to execute model and returns `ModelResponse` or
|
||||
`AIMessage`.
|
||||
"""
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
request: ModelRequest[StateT, ContextT],
|
||||
handler: Callable[[ModelRequest[StateT, ContextT]], Awaitable[ModelResponse]],
|
||||
) -> Awaitable[ModelCallResult]:
|
||||
"""Intercept model execution via async handler callback."""
|
||||
...
|
||||
|
||||
|
||||
class _SyncCallableReturningToolResponse(Protocol):
|
||||
"""Sync callable for tool call interception with handler callback.
|
||||
|
||||
Receives handler callback to execute tool and returns final `ToolMessage` or
|
||||
`Command`.
|
||||
@@ -740,6 +765,22 @@ class _CallableReturningToolResponse(Protocol):
|
||||
...
|
||||
|
||||
|
||||
class _AsyncCallableReturningToolResponse(Protocol):
|
||||
"""Async callable for tool call interception with handler callback.
|
||||
|
||||
Receives async handler callback to execute tool and returns final `ToolMessage` or
|
||||
`Command`.
|
||||
"""
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
request: ToolCallRequest,
|
||||
handler: Callable[[ToolCallRequest], Awaitable[ToolMessage | Command]],
|
||||
) -> Awaitable[ToolMessage | Command]:
|
||||
"""Intercept tool execution via async handler callback."""
|
||||
...
|
||||
|
||||
|
||||
CallableT = TypeVar("CallableT", bound=Callable[..., Any])
|
||||
|
||||
|
||||
@@ -1385,24 +1426,44 @@ def after_agent(
|
||||
|
||||
@overload
|
||||
def dynamic_prompt(
|
||||
func: _CallableReturningSystemMessage[StateT, ContextT],
|
||||
func: _SyncCallableReturningSystemMessage[StateT, ContextT],
|
||||
) -> AgentMiddleware[StateT, ContextT]: ...
|
||||
|
||||
|
||||
@overload
|
||||
def dynamic_prompt(
|
||||
func: _AsyncCallableReturningSystemMessage[StateT, ContextT],
|
||||
) -> AgentMiddleware[StateT, ContextT]: ...
|
||||
|
||||
|
||||
@overload
|
||||
def dynamic_prompt(
|
||||
func: None = None,
|
||||
*,
|
||||
state_schema: type[StateT] | None = None,
|
||||
) -> Callable[
|
||||
[_CallableReturningSystemMessage[StateT, ContextT]],
|
||||
[
|
||||
_SyncCallableReturningSystemMessage[StateT, ContextT]
|
||||
| _AsyncCallableReturningSystemMessage[StateT, ContextT]
|
||||
],
|
||||
AgentMiddleware[StateT, ContextT],
|
||||
]: ...
|
||||
|
||||
|
||||
def dynamic_prompt(
|
||||
func: _CallableReturningSystemMessage[StateT, ContextT] | None = None,
|
||||
func: (
|
||||
_SyncCallableReturningSystemMessage[StateT, ContextT]
|
||||
| _AsyncCallableReturningSystemMessage[StateT, ContextT]
|
||||
| None
|
||||
) = None,
|
||||
*,
|
||||
state_schema: type[StateT] | None = None,
|
||||
) -> (
|
||||
Callable[
|
||||
[_CallableReturningSystemMessage[StateT, ContextT]],
|
||||
[
|
||||
_SyncCallableReturningSystemMessage[StateT, ContextT]
|
||||
| _AsyncCallableReturningSystemMessage[StateT, ContextT]
|
||||
],
|
||||
AgentMiddleware[StateT, ContextT],
|
||||
]
|
||||
| AgentMiddleware[StateT, ContextT]
|
||||
@@ -1418,6 +1479,9 @@ def dynamic_prompt(
|
||||
|
||||
Must accept: `request: ModelRequest` - Model request (contains state and
|
||||
runtime)
|
||||
state_schema: Optional custom state schema type.
|
||||
|
||||
If not provided, uses the default `AgentState` schema.
|
||||
|
||||
Returns:
|
||||
Either an `AgentMiddleware` instance (if func is provided) or a decorator
|
||||
@@ -1456,18 +1520,22 @@ def dynamic_prompt(
|
||||
"""
|
||||
|
||||
def decorator(
|
||||
func: _CallableReturningSystemMessage[StateT, ContextT],
|
||||
func: (
|
||||
_SyncCallableReturningSystemMessage[StateT, ContextT]
|
||||
| _AsyncCallableReturningSystemMessage[StateT, ContextT]
|
||||
),
|
||||
) -> AgentMiddleware[StateT, ContextT]:
|
||||
is_async = iscoroutinefunction(func)
|
||||
|
||||
if is_async:
|
||||
async_func = cast("_AsyncCallableReturningSystemMessage[StateT, ContextT]", func)
|
||||
|
||||
async def async_wrapped(
|
||||
_self: AgentMiddleware[StateT, ContextT],
|
||||
request: ModelRequest,
|
||||
handler: Callable[[ModelRequest], Awaitable[ModelResponse]],
|
||||
request: ModelRequest[StateT, ContextT],
|
||||
handler: Callable[[ModelRequest[StateT, ContextT]], Awaitable[ModelResponse]],
|
||||
) -> ModelCallResult:
|
||||
prompt = await func(request) # type: ignore[misc]
|
||||
prompt = await async_func(request)
|
||||
if isinstance(prompt, SystemMessage):
|
||||
request = request.override(system_message=prompt)
|
||||
else:
|
||||
@@ -1480,18 +1548,20 @@ def dynamic_prompt(
|
||||
middleware_name,
|
||||
(AgentMiddleware,),
|
||||
{
|
||||
"state_schema": AgentState,
|
||||
"state_schema": state_schema or AgentState,
|
||||
"tools": [],
|
||||
"awrap_model_call": async_wrapped,
|
||||
},
|
||||
)()
|
||||
|
||||
sync_func = cast("_SyncCallableReturningSystemMessage[StateT, ContextT]", func)
|
||||
|
||||
def wrapped(
|
||||
_self: AgentMiddleware[StateT, ContextT],
|
||||
request: ModelRequest,
|
||||
handler: Callable[[ModelRequest], ModelResponse],
|
||||
request: ModelRequest[StateT, ContextT],
|
||||
handler: Callable[[ModelRequest[StateT, ContextT]], ModelResponse],
|
||||
) -> ModelCallResult:
|
||||
prompt = cast("Callable[[ModelRequest], SystemMessage | str]", func)(request)
|
||||
prompt = sync_func(request)
|
||||
if isinstance(prompt, SystemMessage):
|
||||
request = request.override(system_message=prompt)
|
||||
else:
|
||||
@@ -1500,11 +1570,11 @@ def dynamic_prompt(
|
||||
|
||||
async def async_wrapped_from_sync(
|
||||
_self: AgentMiddleware[StateT, ContextT],
|
||||
request: ModelRequest,
|
||||
handler: Callable[[ModelRequest], Awaitable[ModelResponse]],
|
||||
request: ModelRequest[StateT, ContextT],
|
||||
handler: Callable[[ModelRequest[StateT, ContextT]], Awaitable[ModelResponse]],
|
||||
) -> ModelCallResult:
|
||||
# Delegate to sync function
|
||||
prompt = cast("Callable[[ModelRequest], SystemMessage | str]", func)(request)
|
||||
prompt = sync_func(request)
|
||||
if isinstance(prompt, SystemMessage):
|
||||
request = request.override(system_message=prompt)
|
||||
else:
|
||||
@@ -1517,7 +1587,7 @@ def dynamic_prompt(
|
||||
middleware_name,
|
||||
(AgentMiddleware,),
|
||||
{
|
||||
"state_schema": AgentState,
|
||||
"state_schema": state_schema or AgentState,
|
||||
"tools": [],
|
||||
"wrap_model_call": wrapped,
|
||||
"awrap_model_call": async_wrapped_from_sync,
|
||||
@@ -1531,7 +1601,13 @@ def dynamic_prompt(
|
||||
|
||||
@overload
|
||||
def wrap_model_call(
|
||||
func: _CallableReturningModelResponse[StateT, ContextT],
|
||||
func: _SyncCallableReturningModelResponse[StateT, ContextT],
|
||||
) -> AgentMiddleware[StateT, ContextT]: ...
|
||||
|
||||
|
||||
@overload
|
||||
def wrap_model_call(
|
||||
func: _AsyncCallableReturningModelResponse[StateT, ContextT],
|
||||
) -> AgentMiddleware[StateT, ContextT]: ...
|
||||
|
||||
|
||||
@@ -1543,20 +1619,30 @@ def wrap_model_call(
|
||||
tools: list[BaseTool] | None = None,
|
||||
name: str | None = None,
|
||||
) -> Callable[
|
||||
[_CallableReturningModelResponse[StateT, ContextT]],
|
||||
[
|
||||
_SyncCallableReturningModelResponse[StateT, ContextT]
|
||||
| _AsyncCallableReturningModelResponse[StateT, ContextT]
|
||||
],
|
||||
AgentMiddleware[StateT, ContextT],
|
||||
]: ...
|
||||
|
||||
|
||||
def wrap_model_call(
|
||||
func: _CallableReturningModelResponse[StateT, ContextT] | None = None,
|
||||
func: (
|
||||
_SyncCallableReturningModelResponse[StateT, ContextT]
|
||||
| _AsyncCallableReturningModelResponse[StateT, ContextT]
|
||||
| None
|
||||
) = None,
|
||||
*,
|
||||
state_schema: type[StateT] | None = None,
|
||||
tools: list[BaseTool] | None = None,
|
||||
name: str | None = None,
|
||||
) -> (
|
||||
Callable[
|
||||
[_CallableReturningModelResponse[StateT, ContextT]],
|
||||
[
|
||||
_SyncCallableReturningModelResponse[StateT, ContextT]
|
||||
| _AsyncCallableReturningModelResponse[StateT, ContextT]
|
||||
],
|
||||
AgentMiddleware[StateT, ContextT],
|
||||
]
|
||||
| AgentMiddleware[StateT, ContextT]
|
||||
@@ -1637,18 +1723,22 @@ def wrap_model_call(
|
||||
"""
|
||||
|
||||
def decorator(
|
||||
func: _CallableReturningModelResponse[StateT, ContextT],
|
||||
func: (
|
||||
_SyncCallableReturningModelResponse[StateT, ContextT]
|
||||
| _AsyncCallableReturningModelResponse[StateT, ContextT]
|
||||
),
|
||||
) -> AgentMiddleware[StateT, ContextT]:
|
||||
is_async = iscoroutinefunction(func)
|
||||
|
||||
if is_async:
|
||||
async_func = cast("_AsyncCallableReturningModelResponse[StateT, ContextT]", func)
|
||||
|
||||
async def async_wrapped(
|
||||
_self: AgentMiddleware[StateT, ContextT],
|
||||
request: ModelRequest,
|
||||
handler: Callable[[ModelRequest], Awaitable[ModelResponse]],
|
||||
request: ModelRequest[StateT, ContextT],
|
||||
handler: Callable[[ModelRequest[StateT, ContextT]], Awaitable[ModelResponse]],
|
||||
) -> ModelCallResult:
|
||||
return await func(request, handler) # type: ignore[misc, arg-type]
|
||||
return await async_func(request, handler)
|
||||
|
||||
middleware_name = name or cast(
|
||||
"str", getattr(func, "__name__", "WrapModelCallMiddleware")
|
||||
@@ -1664,12 +1754,14 @@ def wrap_model_call(
|
||||
},
|
||||
)()
|
||||
|
||||
sync_func = cast("_SyncCallableReturningModelResponse[StateT, ContextT]", func)
|
||||
|
||||
def wrapped(
|
||||
_self: AgentMiddleware[StateT, ContextT],
|
||||
request: ModelRequest,
|
||||
handler: Callable[[ModelRequest], ModelResponse],
|
||||
request: ModelRequest[StateT, ContextT],
|
||||
handler: Callable[[ModelRequest[StateT, ContextT]], ModelResponse],
|
||||
) -> ModelCallResult:
|
||||
return func(request, handler)
|
||||
return sync_func(request, handler)
|
||||
|
||||
middleware_name = name or cast("str", getattr(func, "__name__", "WrapModelCallMiddleware"))
|
||||
|
||||
@@ -1690,8 +1782,14 @@ def wrap_model_call(
|
||||
|
||||
@overload
|
||||
def wrap_tool_call(
|
||||
func: _CallableReturningToolResponse,
|
||||
) -> AgentMiddleware: ...
|
||||
func: _SyncCallableReturningToolResponse,
|
||||
) -> AgentMiddleware[AgentState, None]: ...
|
||||
|
||||
|
||||
@overload
|
||||
def wrap_tool_call(
|
||||
func: _AsyncCallableReturningToolResponse,
|
||||
) -> AgentMiddleware[AgentState, None]: ...
|
||||
|
||||
|
||||
@overload
|
||||
@@ -1701,22 +1799,22 @@ def wrap_tool_call(
|
||||
tools: list[BaseTool] | None = None,
|
||||
name: str | None = None,
|
||||
) -> Callable[
|
||||
[_CallableReturningToolResponse],
|
||||
AgentMiddleware,
|
||||
[_SyncCallableReturningToolResponse | _AsyncCallableReturningToolResponse],
|
||||
AgentMiddleware[AgentState, None],
|
||||
]: ...
|
||||
|
||||
|
||||
def wrap_tool_call(
|
||||
func: _CallableReturningToolResponse | None = None,
|
||||
func: _SyncCallableReturningToolResponse | _AsyncCallableReturningToolResponse | None = None,
|
||||
*,
|
||||
tools: list[BaseTool] | None = None,
|
||||
name: str | None = None,
|
||||
) -> (
|
||||
Callable[
|
||||
[_CallableReturningToolResponse],
|
||||
AgentMiddleware,
|
||||
[_SyncCallableReturningToolResponse | _AsyncCallableReturningToolResponse],
|
||||
AgentMiddleware[AgentState, None],
|
||||
]
|
||||
| AgentMiddleware
|
||||
| AgentMiddleware[AgentState, None]
|
||||
):
|
||||
"""Create middleware with `wrap_tool_call` hook from a function.
|
||||
|
||||
@@ -1797,18 +1895,19 @@ def wrap_tool_call(
|
||||
"""
|
||||
|
||||
def decorator(
|
||||
func: _CallableReturningToolResponse,
|
||||
) -> AgentMiddleware:
|
||||
func: _SyncCallableReturningToolResponse | _AsyncCallableReturningToolResponse,
|
||||
) -> AgentMiddleware[AgentState, None]:
|
||||
is_async = iscoroutinefunction(func)
|
||||
|
||||
if is_async:
|
||||
async_func = cast("_AsyncCallableReturningToolResponse", func)
|
||||
|
||||
async def async_wrapped(
|
||||
_self: AgentMiddleware,
|
||||
_self: AgentMiddleware[AgentState, None],
|
||||
request: ToolCallRequest,
|
||||
handler: Callable[[ToolCallRequest], Awaitable[ToolMessage | Command]],
|
||||
) -> ToolMessage | Command:
|
||||
return await func(request, handler) # type: ignore[arg-type,misc]
|
||||
return await async_func(request, handler)
|
||||
|
||||
middleware_name = name or cast(
|
||||
"str", getattr(func, "__name__", "WrapToolCallMiddleware")
|
||||
@@ -1824,12 +1923,14 @@ def wrap_tool_call(
|
||||
},
|
||||
)()
|
||||
|
||||
sync_func = cast("_SyncCallableReturningToolResponse", func)
|
||||
|
||||
def wrapped(
|
||||
_self: AgentMiddleware,
|
||||
_self: AgentMiddleware[AgentState, None],
|
||||
request: ToolCallRequest,
|
||||
handler: Callable[[ToolCallRequest], ToolMessage | Command],
|
||||
) -> ToolMessage | Command:
|
||||
return func(request, handler)
|
||||
return sync_func(request, handler)
|
||||
|
||||
middleware_name = name or cast("str", getattr(func, "__name__", "WrapToolCallMiddleware"))
|
||||
|
||||
|
||||
@@ -0,0 +1,390 @@
|
||||
"""Tests demonstrating proper typing support for middleware.
|
||||
|
||||
This test file verifies that:
|
||||
1. ModelRequest is properly generic over StateT and ContextT
|
||||
2. Async middleware decorators work without type errors
|
||||
3. Sync middleware decorators work without type errors
|
||||
4. Custom context types flow through properly
|
||||
5. Handler callbacks have correct async/sync signatures
|
||||
|
||||
These tests should pass mypy type checking without any type: ignore comments.
|
||||
"""
|
||||
|
||||
from collections.abc import Awaitable, Callable
|
||||
from typing import TypedDict
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
from langchain_core.messages import AIMessage, HumanMessage, SystemMessage, ToolMessage
|
||||
from langgraph.prebuilt.tool_node import ToolCallRequest
|
||||
from langgraph.types import Command
|
||||
|
||||
from langchain.agents.middleware.types import (
|
||||
AgentMiddleware,
|
||||
AgentState,
|
||||
ModelCallResult,
|
||||
ModelRequest,
|
||||
ModelResponse,
|
||||
after_agent,
|
||||
after_model,
|
||||
before_agent,
|
||||
before_model,
|
||||
dynamic_prompt,
|
||||
wrap_model_call,
|
||||
wrap_tool_call,
|
||||
)
|
||||
|
||||
|
||||
# Custom context type for testing
|
||||
class ServiceContext(TypedDict):
|
||||
"""Custom context for service-level information."""
|
||||
|
||||
user_id: str
|
||||
session_id: str
|
||||
environment: str
|
||||
|
||||
|
||||
class CustomState(AgentState):
|
||||
"""Custom state extending AgentState."""
|
||||
|
||||
custom_field: str
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Test 1: ModelRequest generic typing with custom context
|
||||
# =============================================================================
|
||||
|
||||
|
||||
def test_model_request_generic_context_typing() -> None:
|
||||
"""Test that ModelRequest[StateT, ContextT] properly types the state and runtime fields."""
|
||||
# Create a mock model
|
||||
mock_model = MagicMock()
|
||||
|
||||
# Create ModelRequest with explicit state and context type annotation
|
||||
request: ModelRequest[AgentState, ServiceContext] = ModelRequest(
|
||||
model=mock_model,
|
||||
messages=[HumanMessage(content="Hello")],
|
||||
)
|
||||
|
||||
# The request should be created without type errors
|
||||
assert request.model == mock_model
|
||||
assert len(request.messages) == 1
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Test 2: Sync dynamic_prompt decorator with proper typing
|
||||
# =============================================================================
|
||||
|
||||
|
||||
def test_sync_dynamic_prompt_typing() -> None:
|
||||
"""Test that sync @dynamic_prompt decorator works without type errors."""
|
||||
|
||||
@dynamic_prompt
|
||||
def my_prompt(request: ModelRequest[AgentState, ServiceContext]) -> str:
|
||||
# This should work without type: ignore - accessing generic ModelRequest
|
||||
return f"System prompt for messages: {len(request.messages)}"
|
||||
|
||||
# The decorator should return an AgentMiddleware
|
||||
assert isinstance(my_prompt, AgentMiddleware)
|
||||
|
||||
|
||||
def test_sync_dynamic_prompt_returning_system_message() -> None:
|
||||
"""Test that sync @dynamic_prompt can return SystemMessage."""
|
||||
|
||||
@dynamic_prompt
|
||||
def my_prompt(request: ModelRequest[AgentState, None]) -> SystemMessage:
|
||||
return SystemMessage(content="You are a helpful assistant")
|
||||
|
||||
assert isinstance(my_prompt, AgentMiddleware)
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Test 3: Async dynamic_prompt decorator with proper typing
|
||||
# =============================================================================
|
||||
|
||||
|
||||
def test_async_dynamic_prompt_typing() -> None:
|
||||
"""Test that async @dynamic_prompt decorator works without type errors."""
|
||||
|
||||
@dynamic_prompt
|
||||
async def my_async_prompt(request: ModelRequest[AgentState, ServiceContext]) -> str:
|
||||
# Async function should work without type errors
|
||||
return "Async system prompt"
|
||||
|
||||
assert isinstance(my_async_prompt, AgentMiddleware)
|
||||
|
||||
|
||||
def test_async_dynamic_prompt_returning_system_message() -> None:
|
||||
"""Test that async @dynamic_prompt can return SystemMessage."""
|
||||
|
||||
@dynamic_prompt
|
||||
async def my_async_prompt(request: ModelRequest[AgentState, None]) -> SystemMessage:
|
||||
return SystemMessage(content="Async system message")
|
||||
|
||||
assert isinstance(my_async_prompt, AgentMiddleware)
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Test 4: Sync wrap_model_call decorator with proper handler typing
|
||||
# =============================================================================
|
||||
|
||||
|
||||
def test_sync_wrap_model_call_typing() -> None:
|
||||
"""Test that sync @wrap_model_call decorator properly types the handler."""
|
||||
|
||||
@wrap_model_call
|
||||
def retry_middleware(
|
||||
request: ModelRequest[AgentState, ServiceContext],
|
||||
handler: Callable[[ModelRequest[AgentState, ServiceContext]], ModelResponse],
|
||||
) -> ModelCallResult:
|
||||
# Handler should be typed as sync - no Awaitable
|
||||
return handler(request)
|
||||
|
||||
assert isinstance(retry_middleware, AgentMiddleware)
|
||||
|
||||
|
||||
def test_sync_wrap_model_call_returning_ai_message() -> None:
|
||||
"""Test that sync @wrap_model_call can return AIMessage directly."""
|
||||
|
||||
@wrap_model_call
|
||||
def simple_middleware(
|
||||
request: ModelRequest[AgentState, None],
|
||||
handler: Callable[[ModelRequest[AgentState, None]], ModelResponse],
|
||||
) -> ModelCallResult:
|
||||
# Can return AIMessage directly (converted automatically)
|
||||
return AIMessage(content="Simple response")
|
||||
|
||||
assert isinstance(simple_middleware, AgentMiddleware)
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Test 5: Async wrap_model_call decorator with proper handler typing
|
||||
# =============================================================================
|
||||
|
||||
|
||||
def test_async_wrap_model_call_typing() -> None:
|
||||
"""Test that async @wrap_model_call decorator properly types the async handler."""
|
||||
|
||||
@wrap_model_call
|
||||
async def async_retry_middleware(
|
||||
request: ModelRequest[AgentState, ServiceContext],
|
||||
handler: Callable[[ModelRequest[AgentState, ServiceContext]], Awaitable[ModelResponse]],
|
||||
) -> ModelCallResult:
|
||||
# Handler should be typed as async - returns Awaitable
|
||||
return await handler(request)
|
||||
|
||||
assert isinstance(async_retry_middleware, AgentMiddleware)
|
||||
|
||||
|
||||
def test_async_wrap_model_call_with_error_handling() -> None:
|
||||
"""Test async @wrap_model_call with try/except pattern."""
|
||||
|
||||
@wrap_model_call
|
||||
async def error_handling_middleware(
|
||||
request: ModelRequest[AgentState, None],
|
||||
handler: Callable[[ModelRequest[AgentState, None]], Awaitable[ModelResponse]],
|
||||
) -> ModelCallResult:
|
||||
try:
|
||||
return await handler(request)
|
||||
except Exception:
|
||||
return AIMessage(content="Error occurred")
|
||||
|
||||
assert isinstance(error_handling_middleware, AgentMiddleware)
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Test 6: Sync wrap_tool_call decorator with proper handler typing
|
||||
# =============================================================================
|
||||
|
||||
|
||||
def test_sync_wrap_tool_call_typing() -> None:
|
||||
"""Test that sync @wrap_tool_call decorator properly types the handler."""
|
||||
|
||||
@wrap_tool_call
|
||||
def tool_error_handler(
|
||||
request: ToolCallRequest,
|
||||
handler: Callable[[ToolCallRequest], ToolMessage | Command],
|
||||
) -> ToolMessage | Command:
|
||||
try:
|
||||
return handler(request)
|
||||
except Exception as e:
|
||||
return ToolMessage(
|
||||
content=str(e),
|
||||
tool_call_id=request.tool_call["id"],
|
||||
)
|
||||
|
||||
assert isinstance(tool_error_handler, AgentMiddleware)
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Test 7: Async wrap_tool_call decorator with proper handler typing
|
||||
# =============================================================================
|
||||
|
||||
|
||||
def test_async_wrap_tool_call_typing() -> None:
|
||||
"""Test that async @wrap_tool_call decorator properly types the async handler."""
|
||||
|
||||
@wrap_tool_call
|
||||
async def async_tool_error_handler(
|
||||
request: ToolCallRequest,
|
||||
handler: Callable[[ToolCallRequest], Awaitable[ToolMessage | Command]],
|
||||
) -> ToolMessage | Command:
|
||||
try:
|
||||
return await handler(request)
|
||||
except Exception as e:
|
||||
return ToolMessage(
|
||||
content=str(e),
|
||||
tool_call_id=request.tool_call["id"],
|
||||
)
|
||||
|
||||
assert isinstance(async_tool_error_handler, AgentMiddleware)
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Test 8: before_model/after_model decorators with custom state
|
||||
# =============================================================================
|
||||
|
||||
|
||||
def test_before_model_with_custom_state_typing() -> None:
|
||||
"""Test @before_model decorator with custom state schema."""
|
||||
|
||||
@before_model(state_schema=CustomState)
|
||||
def log_before_model(
|
||||
state: CustomState,
|
||||
runtime: object, # Runtime type comes from langgraph
|
||||
) -> dict[str, object] | None:
|
||||
# Should have access to custom_field without type errors
|
||||
_ = state.get("custom_field") # Access custom field
|
||||
return None
|
||||
|
||||
assert isinstance(log_before_model, AgentMiddleware)
|
||||
assert log_before_model.state_schema == CustomState
|
||||
|
||||
|
||||
def test_after_model_with_custom_state_typing() -> None:
|
||||
"""Test @after_model decorator with custom state schema."""
|
||||
|
||||
@after_model(state_schema=CustomState)
|
||||
def log_after_model(
|
||||
state: CustomState,
|
||||
runtime: object,
|
||||
) -> dict[str, object] | None:
|
||||
return {"custom_field": "updated"}
|
||||
|
||||
assert isinstance(log_after_model, AgentMiddleware)
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Test 9: before_agent/after_agent decorators
|
||||
# =============================================================================
|
||||
|
||||
|
||||
def test_before_agent_async_typing() -> None:
|
||||
"""Test async @before_agent decorator."""
|
||||
|
||||
@before_agent
|
||||
async def setup_agent(
|
||||
state: AgentState,
|
||||
runtime: object,
|
||||
) -> dict[str, object] | None:
|
||||
return None
|
||||
|
||||
assert isinstance(setup_agent, AgentMiddleware)
|
||||
|
||||
|
||||
def test_after_agent_async_typing() -> None:
|
||||
"""Test async @after_agent decorator."""
|
||||
|
||||
@after_agent
|
||||
async def cleanup_agent(
|
||||
state: AgentState,
|
||||
runtime: object,
|
||||
) -> dict[str, object] | None:
|
||||
return None
|
||||
|
||||
assert isinstance(cleanup_agent, AgentMiddleware)
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Test 10: Class-based middleware with proper generic typing
|
||||
# =============================================================================
|
||||
|
||||
|
||||
class TypedMiddleware(AgentMiddleware[CustomState, ServiceContext]):
|
||||
"""Class-based middleware with explicit type parameters."""
|
||||
|
||||
state_schema = CustomState
|
||||
|
||||
def before_model(
|
||||
self,
|
||||
state: CustomState,
|
||||
runtime: object,
|
||||
) -> dict[str, object] | None:
|
||||
# State is properly typed as CustomState
|
||||
return None
|
||||
|
||||
|
||||
def test_class_based_middleware_typing() -> None:
|
||||
"""Test class-based middleware with explicit generics."""
|
||||
middleware = TypedMiddleware()
|
||||
assert middleware.state_schema == CustomState
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Test 11: ModelRequest.override() preserves generic type
|
||||
# =============================================================================
|
||||
|
||||
|
||||
def test_model_request_override_preserves_generic() -> None:
|
||||
"""Test that ModelRequest.override() returns properly typed ModelRequest."""
|
||||
mock_model = MagicMock()
|
||||
|
||||
request: ModelRequest[AgentState, ServiceContext] = ModelRequest(
|
||||
model=mock_model,
|
||||
messages=[HumanMessage(content="Hello")],
|
||||
)
|
||||
|
||||
# override() should return ModelRequest[AgentState, ServiceContext], not ModelRequest[Any, Any]
|
||||
new_request = request.override(system_message=SystemMessage(content="New system prompt"))
|
||||
|
||||
# This should be type-safe
|
||||
assert new_request.system_message is not None
|
||||
assert new_request.system_message.content == "New system prompt"
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Test 12: Multiple middleware in a list (simulating create_agent usage)
|
||||
# =============================================================================
|
||||
|
||||
|
||||
def test_middleware_list_typing() -> None:
|
||||
"""Test that middleware can be collected in a properly typed list."""
|
||||
|
||||
@dynamic_prompt
|
||||
async def system_prompt(request: ModelRequest[AgentState, ServiceContext]) -> SystemMessage:
|
||||
return SystemMessage(content="System")
|
||||
|
||||
@wrap_model_call
|
||||
async def censor_response(
|
||||
request: ModelRequest[AgentState, ServiceContext],
|
||||
handler: Callable[[ModelRequest[AgentState, ServiceContext]], Awaitable[ModelResponse]],
|
||||
) -> ModelCallResult:
|
||||
return await handler(request)
|
||||
|
||||
@wrap_tool_call
|
||||
async def handle_errors(
|
||||
request: ToolCallRequest,
|
||||
handler: Callable[[ToolCallRequest], Awaitable[ToolMessage | Command]],
|
||||
) -> ToolMessage | Command:
|
||||
try:
|
||||
return await handler(request)
|
||||
except Exception as e:
|
||||
return ToolMessage(content=str(e), tool_call_id=request.tool_call["id"])
|
||||
|
||||
# All middleware should be assignable to a list of AgentMiddleware
|
||||
# Note: The decorators return AgentMiddleware with inferred generic parameters
|
||||
middleware_list: list[AgentMiddleware[AgentState, ServiceContext]] = [
|
||||
system_prompt,
|
||||
censor_response,
|
||||
]
|
||||
|
||||
assert len(middleware_list) == 2
|
||||
Reference in New Issue
Block a user