mirror of
https://github.com/hwchase17/langchain.git
synced 2026-07-01 14:47:02 +00:00
chore(langchain): activate mypy warn_return_any (#34249)
Co-authored-by: Mason Daugherty <github@mdrxy.com>
This commit is contained in:
committed by
GitHub
parent
2b4735712c
commit
23ce677870
@@ -454,7 +454,9 @@ def _resolve_schema(
|
||||
if not should_omit:
|
||||
all_annotations[field_name] = field_type
|
||||
|
||||
return TypedDict(schema_name, all_annotations) # type: ignore[operator]
|
||||
# `TypedDict` dynamically creates a class, but type checkers don't infer that
|
||||
# the runtime result satisfies this function's `type` return contract.
|
||||
return cast("type", TypedDict(schema_name, all_annotations)) # type: ignore[operator]
|
||||
|
||||
|
||||
def _extract_metadata(type_: type) -> list[Any]:
|
||||
@@ -493,7 +495,8 @@ def _get_can_jump_to(middleware: AgentMiddleware[Any, Any], hook_name: str) -> l
|
||||
and sync_method is not base_sync_method
|
||||
and hasattr(sync_method, "__can_jump_to__")
|
||||
):
|
||||
return sync_method.__can_jump_to__
|
||||
# `hasattr` proves the metadata exists at runtime, but not its value type.
|
||||
return cast("list[JumpTo]", sync_method.__can_jump_to__)
|
||||
|
||||
# Try async method - only if it's overridden from base class
|
||||
async_method = getattr(middleware.__class__, f"a{hook_name}", None)
|
||||
@@ -502,7 +505,8 @@ def _get_can_jump_to(middleware: AgentMiddleware[Any, Any], hook_name: str) -> l
|
||||
and async_method is not base_async_method
|
||||
and hasattr(async_method, "__can_jump_to__")
|
||||
):
|
||||
return async_method.__can_jump_to__
|
||||
# `hasattr` proves the metadata exists at runtime, but not its value type.
|
||||
return cast("list[JumpTo]", async_method.__can_jump_to__)
|
||||
|
||||
return []
|
||||
|
||||
|
||||
@@ -764,7 +764,7 @@ class ShellToolMiddleware(AgentMiddleware[ShellToolState[ResponseT], ContextT, R
|
||||
payload: dict[str, Any],
|
||||
*,
|
||||
tool_call_id: str | None,
|
||||
) -> Any:
|
||||
) -> ToolMessage | str:
|
||||
session = resources.session
|
||||
|
||||
if payload.get("restart"):
|
||||
|
||||
@@ -2,6 +2,7 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import warnings
|
||||
from collections.abc import Awaitable, Callable, Sequence
|
||||
from dataclasses import dataclass, field, replace
|
||||
from inspect import iscoroutinefunction
|
||||
@@ -12,17 +13,12 @@ from typing import (
|
||||
Generic,
|
||||
Literal,
|
||||
Protocol,
|
||||
TypeAlias,
|
||||
cast,
|
||||
overload,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from collections.abc import Awaitable
|
||||
|
||||
# Needed as top level import for Pydantic schema generation on AgentState
|
||||
import warnings
|
||||
from typing import TypeAlias
|
||||
|
||||
from langchain_core.messages import (
|
||||
AIMessage,
|
||||
AnyMessage,
|
||||
@@ -33,15 +29,15 @@ from langchain_core.messages import (
|
||||
from langgraph.channels.ephemeral_value import EphemeralValue
|
||||
from langgraph.graph.message import add_messages
|
||||
from langgraph.prebuilt.tool_node import ToolCallRequest, ToolCallWrapper
|
||||
from langgraph.runtime import Runtime
|
||||
from langgraph.types import Command
|
||||
from langgraph.typing import ContextT
|
||||
from typing_extensions import NotRequired, Required, TypedDict, TypeVar, Unpack
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from langchain_core.language_models.chat_models import BaseChatModel
|
||||
from langchain_core.tools import BaseTool
|
||||
from langgraph.runtime import Runtime
|
||||
from langgraph.stream._mux import TransformerFactory
|
||||
from langgraph.types import Command
|
||||
|
||||
from langchain.agents.structured_output import ResponseFormat
|
||||
|
||||
@@ -811,38 +807,54 @@ class AgentMiddleware(Generic[StateT, ContextT, ResponseT]):
|
||||
raise NotImplementedError(msg)
|
||||
|
||||
|
||||
class _CallableWithStateAndRuntime(Protocol[StateT_contra, ContextT]):
|
||||
"""Callable with `AgentState` and `Runtime` as arguments."""
|
||||
_SyncCallableWithStateAndRuntime = Callable[
|
||||
[StateT_contra, Runtime[ContextT]], dict[str, Any] | Command[Any] | None
|
||||
]
|
||||
_AsyncCallableWithStateAndRuntime = Callable[
|
||||
[StateT_contra, Runtime[ContextT]], Awaitable[dict[str, Any] | Command[Any] | None]
|
||||
]
|
||||
_CallableWithStateAndRuntime = (
|
||||
_SyncCallableWithStateAndRuntime[StateT_contra, ContextT]
|
||||
| _AsyncCallableWithStateAndRuntime[StateT_contra, ContextT]
|
||||
)
|
||||
|
||||
def __call__(
|
||||
self, state: StateT_contra, runtime: Runtime[ContextT]
|
||||
) -> dict[str, Any] | Command[Any] | None | Awaitable[dict[str, Any] | Command[Any] | None]:
|
||||
"""Perform some logic with the state and runtime."""
|
||||
...
|
||||
_SyncCallableReturningSystemMessage = Callable[[ModelRequest[ContextT]], str | SystemMessage]
|
||||
_AsyncCallableReturningSystemMessage = Callable[
|
||||
[ModelRequest[ContextT]], Awaitable[str | SystemMessage]
|
||||
]
|
||||
_CallableReturningSystemMessage = (
|
||||
_SyncCallableReturningSystemMessage[ContextT] | _AsyncCallableReturningSystemMessage[ContextT]
|
||||
)
|
||||
|
||||
|
||||
class _CallableReturningSystemMessage(Protocol[StateT_contra, ContextT]): # type: ignore[misc]
|
||||
"""Callable that returns a prompt string or SystemMessage given `ModelRequest`."""
|
||||
|
||||
def __call__(
|
||||
self, request: ModelRequest[ContextT]
|
||||
) -> str | SystemMessage | Awaitable[str | SystemMessage]:
|
||||
"""Generate a system prompt string or SystemMessage based on the request."""
|
||||
...
|
||||
|
||||
|
||||
class _CallableReturningModelResponse(Protocol[StateT_contra, ContextT, ResponseT]): # type: ignore[misc]
|
||||
class _CallableReturningModelResponse(Protocol[ContextT, ResponseT]):
|
||||
"""Callable for model call interception with handler callback.
|
||||
|
||||
Receives handler callback to execute model and returns `ModelResponse` or
|
||||
`AIMessage`.
|
||||
"""
|
||||
|
||||
@overload
|
||||
def __call__(
|
||||
self,
|
||||
request: ModelRequest[ContextT],
|
||||
handler: Callable[[ModelRequest[ContextT]], ModelResponse[ResponseT]],
|
||||
) -> ModelResponse[ResponseT] | AIMessage:
|
||||
) -> ModelResponse[ResponseT] | AIMessage: ...
|
||||
|
||||
@overload
|
||||
def __call__(
|
||||
self,
|
||||
request: ModelRequest[ContextT],
|
||||
handler: Callable[[ModelRequest[ContextT]], Awaitable[ModelResponse[ResponseT]]],
|
||||
) -> Awaitable[ModelResponse[ResponseT] | AIMessage]: ...
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
request: ModelRequest[ContextT],
|
||||
handler: Callable[
|
||||
[ModelRequest[ContextT]], ModelResponse[ResponseT] | Awaitable[ModelResponse[ResponseT]]
|
||||
],
|
||||
) -> ModelResponse[ResponseT] | AIMessage | Awaitable[ModelResponse[ResponseT] | AIMessage]:
|
||||
"""Intercept model execution via handler callback."""
|
||||
...
|
||||
|
||||
@@ -854,11 +866,27 @@ class _CallableReturningToolResponse(Protocol):
|
||||
`Command`.
|
||||
"""
|
||||
|
||||
@overload
|
||||
def __call__(
|
||||
self,
|
||||
request: ToolCallRequest,
|
||||
handler: Callable[[ToolCallRequest], ToolMessage | Command[Any]],
|
||||
) -> ToolMessage | Command[Any]:
|
||||
) -> ToolMessage | Command[Any]: ...
|
||||
|
||||
@overload
|
||||
def __call__(
|
||||
self,
|
||||
request: ToolCallRequest,
|
||||
handler: Callable[[ToolCallRequest], Awaitable[ToolMessage | Command[Any]]],
|
||||
) -> Awaitable[ToolMessage | Command[Any]]: ...
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
request: ToolCallRequest,
|
||||
handler: Callable[
|
||||
[ToolCallRequest], ToolMessage | Command[Any] | Awaitable[ToolMessage | Command[Any]]
|
||||
],
|
||||
) -> ToolMessage | Command[Any] | Awaitable[ToolMessage | Command[Any]]:
|
||||
"""Intercept tool execution via handler callback."""
|
||||
...
|
||||
|
||||
@@ -1041,7 +1069,11 @@ def before_model(
|
||||
state: StateT,
|
||||
runtime: Runtime[ContextT],
|
||||
) -> dict[str, Any] | Command[Any] | None:
|
||||
return await func(state, runtime) # type: ignore[misc]
|
||||
# `iscoroutinefunction` narrows `func` at runtime, but type checkers
|
||||
# cannot narrow this sync-or-async callable union.
|
||||
return await cast("_AsyncCallableWithStateAndRuntime[StateT, ContextT]", func)(
|
||||
state, runtime
|
||||
)
|
||||
|
||||
# Preserve can_jump_to metadata on the wrapped function
|
||||
if func_can_jump_to:
|
||||
@@ -1051,22 +1083,29 @@ def before_model(
|
||||
"str", getattr(func, "__name__", "BeforeModelMiddleware")
|
||||
)
|
||||
|
||||
return type(
|
||||
middleware_name,
|
||||
(AgentMiddleware,),
|
||||
{
|
||||
"state_schema": state_schema or AgentState,
|
||||
"tools": tools or [],
|
||||
"abefore_model": async_wrapped,
|
||||
},
|
||||
)()
|
||||
# `type(...)` builds the correct middleware subclass at runtime, but
|
||||
# type checkers cannot infer its generic `AgentMiddleware` parameters.
|
||||
return cast(
|
||||
"AgentMiddleware[StateT, ContextT]",
|
||||
type(
|
||||
middleware_name,
|
||||
(AgentMiddleware,),
|
||||
{
|
||||
"state_schema": state_schema or AgentState,
|
||||
"tools": tools or [],
|
||||
"abefore_model": async_wrapped,
|
||||
},
|
||||
)(),
|
||||
)
|
||||
|
||||
def wrapped(
|
||||
_self: AgentMiddleware[StateT, ContextT],
|
||||
state: StateT,
|
||||
runtime: Runtime[ContextT],
|
||||
) -> dict[str, Any] | Command[Any] | None:
|
||||
return func(state, runtime) # type: ignore[return-value]
|
||||
# `iscoroutinefunction` narrows `func` at runtime, but type checkers
|
||||
# cannot narrow this sync-or-async callable union.
|
||||
return cast("_SyncCallableWithStateAndRuntime[StateT, ContextT]", func)(state, runtime)
|
||||
|
||||
# Preserve can_jump_to metadata on the wrapped function
|
||||
if func_can_jump_to:
|
||||
@@ -1075,15 +1114,20 @@ def before_model(
|
||||
# Use function name as default if no name provided
|
||||
middleware_name = name or cast("str", getattr(func, "__name__", "BeforeModelMiddleware"))
|
||||
|
||||
return type(
|
||||
middleware_name,
|
||||
(AgentMiddleware,),
|
||||
{
|
||||
"state_schema": state_schema or AgentState,
|
||||
"tools": tools or [],
|
||||
"before_model": wrapped,
|
||||
},
|
||||
)()
|
||||
# `type(...)` builds the correct middleware subclass at runtime, but
|
||||
# type checkers cannot infer its generic `AgentMiddleware` parameters.
|
||||
return cast(
|
||||
"AgentMiddleware[StateT, ContextT]",
|
||||
type(
|
||||
middleware_name,
|
||||
(AgentMiddleware,),
|
||||
{
|
||||
"state_schema": state_schema or AgentState,
|
||||
"tools": tools or [],
|
||||
"before_model": wrapped,
|
||||
},
|
||||
)(),
|
||||
)
|
||||
|
||||
if func is not None:
|
||||
return decorator(func)
|
||||
@@ -1201,7 +1245,11 @@ def after_model(
|
||||
state: StateT,
|
||||
runtime: Runtime[ContextT],
|
||||
) -> dict[str, Any] | Command[Any] | None:
|
||||
return await func(state, runtime) # type: ignore[misc]
|
||||
# `iscoroutinefunction` narrows `func` at runtime, but type checkers
|
||||
# cannot narrow this sync-or-async callable union.
|
||||
return await cast("_AsyncCallableWithStateAndRuntime[StateT, ContextT]", func)(
|
||||
state, runtime
|
||||
)
|
||||
|
||||
# Preserve can_jump_to metadata on the wrapped function
|
||||
if func_can_jump_to:
|
||||
@@ -1209,22 +1257,29 @@ def after_model(
|
||||
|
||||
middleware_name = name or cast("str", getattr(func, "__name__", "AfterModelMiddleware"))
|
||||
|
||||
return type(
|
||||
middleware_name,
|
||||
(AgentMiddleware,),
|
||||
{
|
||||
"state_schema": state_schema or AgentState,
|
||||
"tools": tools or [],
|
||||
"aafter_model": async_wrapped,
|
||||
},
|
||||
)()
|
||||
# `type(...)` builds the correct middleware subclass at runtime, but
|
||||
# type checkers cannot infer its generic `AgentMiddleware` parameters.
|
||||
return cast(
|
||||
"AgentMiddleware[StateT, ContextT]",
|
||||
type(
|
||||
middleware_name,
|
||||
(AgentMiddleware,),
|
||||
{
|
||||
"state_schema": state_schema or AgentState,
|
||||
"tools": tools or [],
|
||||
"aafter_model": async_wrapped,
|
||||
},
|
||||
)(),
|
||||
)
|
||||
|
||||
def wrapped(
|
||||
_self: AgentMiddleware[StateT, ContextT],
|
||||
state: StateT,
|
||||
runtime: Runtime[ContextT],
|
||||
) -> dict[str, Any] | Command[Any] | None:
|
||||
return func(state, runtime) # type: ignore[return-value]
|
||||
# `iscoroutinefunction` narrows `func` at runtime, but type checkers
|
||||
# cannot narrow this sync-or-async callable union.
|
||||
return cast("_SyncCallableWithStateAndRuntime[StateT, ContextT]", func)(state, runtime)
|
||||
|
||||
# Preserve can_jump_to metadata on the wrapped function
|
||||
if func_can_jump_to:
|
||||
@@ -1233,15 +1288,20 @@ def after_model(
|
||||
# Use function name as default if no name provided
|
||||
middleware_name = name or cast("str", getattr(func, "__name__", "AfterModelMiddleware"))
|
||||
|
||||
return type(
|
||||
middleware_name,
|
||||
(AgentMiddleware,),
|
||||
{
|
||||
"state_schema": state_schema or AgentState,
|
||||
"tools": tools or [],
|
||||
"after_model": wrapped,
|
||||
},
|
||||
)()
|
||||
# `type(...)` builds the correct middleware subclass at runtime, but
|
||||
# type checkers cannot infer its generic `AgentMiddleware` parameters.
|
||||
return cast(
|
||||
"AgentMiddleware[StateT, ContextT]",
|
||||
type(
|
||||
middleware_name,
|
||||
(AgentMiddleware,),
|
||||
{
|
||||
"state_schema": state_schema or AgentState,
|
||||
"tools": tools or [],
|
||||
"after_model": wrapped,
|
||||
},
|
||||
)(),
|
||||
)
|
||||
|
||||
if func is not None:
|
||||
return decorator(func)
|
||||
@@ -1392,7 +1452,11 @@ def before_agent(
|
||||
state: StateT,
|
||||
runtime: Runtime[ContextT],
|
||||
) -> dict[str, Any] | Command[Any] | None:
|
||||
return await func(state, runtime) # type: ignore[misc]
|
||||
# `iscoroutinefunction` narrows `func` at runtime, but type checkers
|
||||
# cannot narrow this sync-or-async callable union.
|
||||
return await cast("_AsyncCallableWithStateAndRuntime[StateT, ContextT]", func)(
|
||||
state, runtime
|
||||
)
|
||||
|
||||
# Preserve can_jump_to metadata on the wrapped function
|
||||
if func_can_jump_to:
|
||||
@@ -1402,22 +1466,29 @@ def before_agent(
|
||||
"str", getattr(func, "__name__", "BeforeAgentMiddleware")
|
||||
)
|
||||
|
||||
return type(
|
||||
middleware_name,
|
||||
(AgentMiddleware,),
|
||||
{
|
||||
"state_schema": state_schema or AgentState,
|
||||
"tools": tools or [],
|
||||
"abefore_agent": async_wrapped,
|
||||
},
|
||||
)()
|
||||
# `type(...)` builds the correct middleware subclass at runtime, but
|
||||
# type checkers cannot infer its generic `AgentMiddleware` parameters.
|
||||
return cast(
|
||||
"AgentMiddleware[StateT, ContextT]",
|
||||
type(
|
||||
middleware_name,
|
||||
(AgentMiddleware,),
|
||||
{
|
||||
"state_schema": state_schema or AgentState,
|
||||
"tools": tools or [],
|
||||
"abefore_agent": async_wrapped,
|
||||
},
|
||||
)(),
|
||||
)
|
||||
|
||||
def wrapped(
|
||||
_self: AgentMiddleware[StateT, ContextT],
|
||||
state: StateT,
|
||||
runtime: Runtime[ContextT],
|
||||
) -> dict[str, Any] | Command[Any] | None:
|
||||
return func(state, runtime) # type: ignore[return-value]
|
||||
# `iscoroutinefunction` narrows `func` at runtime, but type checkers
|
||||
# cannot narrow this sync-or-async callable union.
|
||||
return cast("_SyncCallableWithStateAndRuntime[StateT, ContextT]", func)(state, runtime)
|
||||
|
||||
# Preserve can_jump_to metadata on the wrapped function
|
||||
if func_can_jump_to:
|
||||
@@ -1426,15 +1497,20 @@ def before_agent(
|
||||
# Use function name as default if no name provided
|
||||
middleware_name = name or cast("str", getattr(func, "__name__", "BeforeAgentMiddleware"))
|
||||
|
||||
return type(
|
||||
middleware_name,
|
||||
(AgentMiddleware,),
|
||||
{
|
||||
"state_schema": state_schema or AgentState,
|
||||
"tools": tools or [],
|
||||
"before_agent": wrapped,
|
||||
},
|
||||
)()
|
||||
# `type(...)` builds the correct middleware subclass at runtime, but
|
||||
# type checkers cannot infer its generic `AgentMiddleware` parameters.
|
||||
return cast(
|
||||
"AgentMiddleware[StateT, ContextT]",
|
||||
type(
|
||||
middleware_name,
|
||||
(AgentMiddleware,),
|
||||
{
|
||||
"state_schema": state_schema or AgentState,
|
||||
"tools": tools or [],
|
||||
"before_agent": wrapped,
|
||||
},
|
||||
)(),
|
||||
)
|
||||
|
||||
if func is not None:
|
||||
return decorator(func)
|
||||
@@ -1553,7 +1629,11 @@ def after_agent(
|
||||
state: StateT,
|
||||
runtime: Runtime[ContextT],
|
||||
) -> dict[str, Any] | Command[Any] | None:
|
||||
return await func(state, runtime) # type: ignore[misc]
|
||||
# `iscoroutinefunction` narrows `func` at runtime, but type checkers
|
||||
# cannot narrow this sync-or-async callable union.
|
||||
return await cast("_AsyncCallableWithStateAndRuntime[StateT, ContextT]", func)(
|
||||
state, runtime
|
||||
)
|
||||
|
||||
# Preserve can_jump_to metadata on the wrapped function
|
||||
if func_can_jump_to:
|
||||
@@ -1561,22 +1641,29 @@ def after_agent(
|
||||
|
||||
middleware_name = name or cast("str", getattr(func, "__name__", "AfterAgentMiddleware"))
|
||||
|
||||
return type(
|
||||
middleware_name,
|
||||
(AgentMiddleware,),
|
||||
{
|
||||
"state_schema": state_schema or AgentState,
|
||||
"tools": tools or [],
|
||||
"aafter_agent": async_wrapped,
|
||||
},
|
||||
)()
|
||||
# `type(...)` builds the correct middleware subclass at runtime, but
|
||||
# type checkers cannot infer its generic `AgentMiddleware` parameters.
|
||||
return cast(
|
||||
"AgentMiddleware[StateT, ContextT]",
|
||||
type(
|
||||
middleware_name,
|
||||
(AgentMiddleware,),
|
||||
{
|
||||
"state_schema": state_schema or AgentState,
|
||||
"tools": tools or [],
|
||||
"aafter_agent": async_wrapped,
|
||||
},
|
||||
)(),
|
||||
)
|
||||
|
||||
def wrapped(
|
||||
_self: AgentMiddleware[StateT, ContextT],
|
||||
state: StateT,
|
||||
runtime: Runtime[ContextT],
|
||||
) -> dict[str, Any] | Command[Any] | None:
|
||||
return func(state, runtime) # type: ignore[return-value]
|
||||
# `iscoroutinefunction` narrows `func` at runtime, but type checkers
|
||||
# cannot narrow this sync-or-async callable union.
|
||||
return cast("_SyncCallableWithStateAndRuntime[StateT, ContextT]", func)(state, runtime)
|
||||
|
||||
# Preserve can_jump_to metadata on the wrapped function
|
||||
if func_can_jump_to:
|
||||
@@ -1585,15 +1672,20 @@ def after_agent(
|
||||
# Use function name as default if no name provided
|
||||
middleware_name = name or cast("str", getattr(func, "__name__", "AfterAgentMiddleware"))
|
||||
|
||||
return type(
|
||||
middleware_name,
|
||||
(AgentMiddleware,),
|
||||
{
|
||||
"state_schema": state_schema or AgentState,
|
||||
"tools": tools or [],
|
||||
"after_agent": wrapped,
|
||||
},
|
||||
)()
|
||||
# `type(...)` builds the correct middleware subclass at runtime, but
|
||||
# type checkers cannot infer its generic `AgentMiddleware` parameters.
|
||||
return cast(
|
||||
"AgentMiddleware[StateT, ContextT]",
|
||||
type(
|
||||
middleware_name,
|
||||
(AgentMiddleware,),
|
||||
{
|
||||
"state_schema": state_schema or AgentState,
|
||||
"tools": tools or [],
|
||||
"after_agent": wrapped,
|
||||
},
|
||||
)(),
|
||||
)
|
||||
|
||||
if func is not None:
|
||||
return decorator(func)
|
||||
@@ -1602,7 +1694,7 @@ def after_agent(
|
||||
|
||||
@overload
|
||||
def dynamic_prompt(
|
||||
func: _CallableReturningSystemMessage[StateT, ContextT],
|
||||
func: _CallableReturningSystemMessage[ContextT],
|
||||
) -> AgentMiddleware[StateT, ContextT]: ...
|
||||
|
||||
|
||||
@@ -1610,16 +1702,16 @@ def dynamic_prompt(
|
||||
def dynamic_prompt(
|
||||
func: None = None,
|
||||
) -> Callable[
|
||||
[_CallableReturningSystemMessage[StateT, ContextT]],
|
||||
[_CallableReturningSystemMessage[ContextT]],
|
||||
AgentMiddleware[StateT, ContextT],
|
||||
]: ...
|
||||
|
||||
|
||||
def dynamic_prompt(
|
||||
func: _CallableReturningSystemMessage[StateT, ContextT] | None = None,
|
||||
func: _CallableReturningSystemMessage[ContextT] | None = None,
|
||||
) -> (
|
||||
Callable[
|
||||
[_CallableReturningSystemMessage[StateT, ContextT]],
|
||||
[_CallableReturningSystemMessage[ContextT]],
|
||||
AgentMiddleware[StateT, ContextT],
|
||||
]
|
||||
| AgentMiddleware[StateT, ContextT]
|
||||
@@ -1673,7 +1765,7 @@ def dynamic_prompt(
|
||||
"""
|
||||
|
||||
def decorator(
|
||||
func: _CallableReturningSystemMessage[StateT, ContextT],
|
||||
func: _CallableReturningSystemMessage[ContextT],
|
||||
) -> AgentMiddleware[StateT, ContextT]:
|
||||
is_async = iscoroutinefunction(func)
|
||||
|
||||
@@ -1684,7 +1776,7 @@ def dynamic_prompt(
|
||||
request: ModelRequest[ContextT],
|
||||
handler: Callable[[ModelRequest[ContextT]], Awaitable[ModelResponse[Any]]],
|
||||
) -> ModelResponse[Any] | AIMessage:
|
||||
prompt = await func(request) # type: ignore[misc]
|
||||
prompt = await cast("_AsyncCallableReturningSystemMessage[ContextT]", func)(request)
|
||||
if isinstance(prompt, SystemMessage):
|
||||
request = request.override(system_message=prompt)
|
||||
else:
|
||||
@@ -1693,22 +1785,27 @@ def dynamic_prompt(
|
||||
|
||||
middleware_name = cast("str", getattr(func, "__name__", "DynamicPromptMiddleware"))
|
||||
|
||||
return type(
|
||||
middleware_name,
|
||||
(AgentMiddleware,),
|
||||
{
|
||||
"state_schema": AgentState,
|
||||
"tools": [],
|
||||
"awrap_model_call": async_wrapped,
|
||||
},
|
||||
)()
|
||||
# `type(...)` builds the correct middleware subclass at runtime, but
|
||||
# type checkers cannot infer its generic `AgentMiddleware` parameters.
|
||||
return cast(
|
||||
"AgentMiddleware[StateT, ContextT]",
|
||||
type(
|
||||
middleware_name,
|
||||
(AgentMiddleware,),
|
||||
{
|
||||
"state_schema": AgentState,
|
||||
"tools": [],
|
||||
"awrap_model_call": async_wrapped,
|
||||
},
|
||||
)(),
|
||||
)
|
||||
|
||||
def wrapped(
|
||||
_self: AgentMiddleware[StateT, ContextT],
|
||||
request: ModelRequest[ContextT],
|
||||
handler: Callable[[ModelRequest[ContextT]], ModelResponse[Any]],
|
||||
) -> ModelResponse[Any] | AIMessage:
|
||||
prompt = cast("Callable[[ModelRequest[ContextT]], SystemMessage | str]", func)(request)
|
||||
prompt = cast("_SyncCallableReturningSystemMessage[ContextT]", func)(request)
|
||||
if isinstance(prompt, SystemMessage):
|
||||
request = request.override(system_message=prompt)
|
||||
else:
|
||||
@@ -1721,7 +1818,7 @@ def dynamic_prompt(
|
||||
handler: Callable[[ModelRequest[ContextT]], Awaitable[ModelResponse[Any]]],
|
||||
) -> ModelResponse[Any] | AIMessage:
|
||||
# Delegate to sync function
|
||||
prompt = cast("Callable[[ModelRequest[ContextT]], SystemMessage | str]", func)(request)
|
||||
prompt = cast("_SyncCallableReturningSystemMessage[ContextT]", func)(request)
|
||||
if isinstance(prompt, SystemMessage):
|
||||
request = request.override(system_message=prompt)
|
||||
else:
|
||||
@@ -1730,16 +1827,21 @@ def dynamic_prompt(
|
||||
|
||||
middleware_name = cast("str", getattr(func, "__name__", "DynamicPromptMiddleware"))
|
||||
|
||||
return type(
|
||||
middleware_name,
|
||||
(AgentMiddleware,),
|
||||
{
|
||||
"state_schema": AgentState,
|
||||
"tools": [],
|
||||
"wrap_model_call": wrapped,
|
||||
"awrap_model_call": async_wrapped_from_sync,
|
||||
},
|
||||
)()
|
||||
# `type(...)` builds the correct middleware subclass at runtime, but
|
||||
# type checkers cannot infer its generic `AgentMiddleware` parameters.
|
||||
return cast(
|
||||
"AgentMiddleware[StateT, ContextT]",
|
||||
type(
|
||||
middleware_name,
|
||||
(AgentMiddleware,),
|
||||
{
|
||||
"state_schema": AgentState,
|
||||
"tools": [],
|
||||
"wrap_model_call": wrapped,
|
||||
"awrap_model_call": async_wrapped_from_sync,
|
||||
},
|
||||
)(),
|
||||
)
|
||||
|
||||
if func is not None:
|
||||
return decorator(func)
|
||||
@@ -1748,7 +1850,7 @@ def dynamic_prompt(
|
||||
|
||||
@overload
|
||||
def wrap_model_call(
|
||||
func: _CallableReturningModelResponse[StateT, ContextT, ResponseT],
|
||||
func: _CallableReturningModelResponse[ContextT, ResponseT],
|
||||
) -> AgentMiddleware[StateT, ContextT]: ...
|
||||
|
||||
|
||||
@@ -1760,20 +1862,20 @@ def wrap_model_call(
|
||||
tools: list[BaseTool] | None = None,
|
||||
name: str | None = None,
|
||||
) -> Callable[
|
||||
[_CallableReturningModelResponse[StateT, ContextT, ResponseT]],
|
||||
[_CallableReturningModelResponse[ContextT, ResponseT]],
|
||||
AgentMiddleware[StateT, ContextT],
|
||||
]: ...
|
||||
|
||||
|
||||
def wrap_model_call(
|
||||
func: _CallableReturningModelResponse[StateT, ContextT, ResponseT] | None = None,
|
||||
func: _CallableReturningModelResponse[ContextT, ResponseT] | None = None,
|
||||
*,
|
||||
state_schema: type[StateT] | None = None,
|
||||
tools: list[BaseTool] | None = None,
|
||||
name: str | None = None,
|
||||
) -> (
|
||||
Callable[
|
||||
[_CallableReturningModelResponse[StateT, ContextT, ResponseT]],
|
||||
[_CallableReturningModelResponse[ContextT, ResponseT]],
|
||||
AgentMiddleware[StateT, ContextT],
|
||||
]
|
||||
| AgentMiddleware[StateT, ContextT]
|
||||
@@ -1854,7 +1956,7 @@ def wrap_model_call(
|
||||
"""
|
||||
|
||||
def decorator(
|
||||
func: _CallableReturningModelResponse[StateT, ContextT, ResponseT],
|
||||
func: _CallableReturningModelResponse[ContextT, ResponseT],
|
||||
) -> AgentMiddleware[StateT, ContextT]:
|
||||
is_async = iscoroutinefunction(func)
|
||||
|
||||
@@ -1865,21 +1967,26 @@ def wrap_model_call(
|
||||
request: ModelRequest[ContextT],
|
||||
handler: Callable[[ModelRequest[ContextT]], Awaitable[ModelResponse[ResponseT]]],
|
||||
) -> ModelResponse[ResponseT] | AIMessage:
|
||||
return await func(request, handler) # type: ignore[misc, arg-type]
|
||||
return await func(request, handler)
|
||||
|
||||
middleware_name = name or cast(
|
||||
"str", getattr(func, "__name__", "WrapModelCallMiddleware")
|
||||
)
|
||||
|
||||
return type(
|
||||
middleware_name,
|
||||
(AgentMiddleware,),
|
||||
{
|
||||
"state_schema": state_schema or AgentState,
|
||||
"tools": tools or [],
|
||||
"awrap_model_call": async_wrapped,
|
||||
},
|
||||
)()
|
||||
# `type(...)` builds the correct middleware subclass at runtime, but
|
||||
# type checkers cannot infer its generic `AgentMiddleware` parameters.
|
||||
return cast(
|
||||
"AgentMiddleware[StateT, ContextT]",
|
||||
type(
|
||||
middleware_name,
|
||||
(AgentMiddleware,),
|
||||
{
|
||||
"state_schema": state_schema or AgentState,
|
||||
"tools": tools or [],
|
||||
"awrap_model_call": async_wrapped,
|
||||
},
|
||||
)(),
|
||||
)
|
||||
|
||||
def wrapped(
|
||||
_self: AgentMiddleware[StateT, ContextT],
|
||||
@@ -1890,15 +1997,20 @@ def wrap_model_call(
|
||||
|
||||
middleware_name = name or cast("str", getattr(func, "__name__", "WrapModelCallMiddleware"))
|
||||
|
||||
return type(
|
||||
middleware_name,
|
||||
(AgentMiddleware,),
|
||||
{
|
||||
"state_schema": state_schema or AgentState,
|
||||
"tools": tools or [],
|
||||
"wrap_model_call": wrapped,
|
||||
},
|
||||
)()
|
||||
# `type(...)` builds the correct middleware subclass at runtime, but
|
||||
# type checkers cannot infer its generic `AgentMiddleware` parameters.
|
||||
return cast(
|
||||
"AgentMiddleware[StateT, ContextT]",
|
||||
type(
|
||||
middleware_name,
|
||||
(AgentMiddleware,),
|
||||
{
|
||||
"state_schema": state_schema or AgentState,
|
||||
"tools": tools or [],
|
||||
"wrap_model_call": wrapped,
|
||||
},
|
||||
)(),
|
||||
)
|
||||
|
||||
if func is not None:
|
||||
return decorator(func)
|
||||
@@ -2025,21 +2137,26 @@ def wrap_tool_call(
|
||||
request: ToolCallRequest,
|
||||
handler: Callable[[ToolCallRequest], Awaitable[ToolMessage | Command[Any]]],
|
||||
) -> ToolMessage | Command[Any]:
|
||||
return await func(request, handler) # type: ignore[arg-type,misc]
|
||||
return await func(request, handler)
|
||||
|
||||
middleware_name = name or cast(
|
||||
"str", getattr(func, "__name__", "WrapToolCallMiddleware")
|
||||
)
|
||||
|
||||
return type(
|
||||
middleware_name,
|
||||
(AgentMiddleware,),
|
||||
{
|
||||
"state_schema": AgentState,
|
||||
"tools": tools or [],
|
||||
"awrap_tool_call": async_wrapped,
|
||||
},
|
||||
)()
|
||||
# `type(...)` builds the correct middleware subclass at runtime, but
|
||||
# type checkers cannot infer its generic `AgentMiddleware` parameters.
|
||||
return cast(
|
||||
"AgentMiddleware",
|
||||
type(
|
||||
middleware_name,
|
||||
(AgentMiddleware,),
|
||||
{
|
||||
"state_schema": AgentState,
|
||||
"tools": tools or [],
|
||||
"awrap_tool_call": async_wrapped,
|
||||
},
|
||||
)(),
|
||||
)
|
||||
|
||||
def wrapped(
|
||||
_self: AgentMiddleware,
|
||||
@@ -2050,15 +2167,20 @@ def wrap_tool_call(
|
||||
|
||||
middleware_name = name or cast("str", getattr(func, "__name__", "WrapToolCallMiddleware"))
|
||||
|
||||
return type(
|
||||
middleware_name,
|
||||
(AgentMiddleware,),
|
||||
{
|
||||
"state_schema": AgentState,
|
||||
"tools": tools or [],
|
||||
"wrap_tool_call": wrapped,
|
||||
},
|
||||
)()
|
||||
# `type(...)` builds the correct middleware subclass at runtime, but
|
||||
# type checkers cannot infer its generic `AgentMiddleware` parameters.
|
||||
return cast(
|
||||
"AgentMiddleware",
|
||||
type(
|
||||
middleware_name,
|
||||
(AgentMiddleware,),
|
||||
{
|
||||
"state_schema": AgentState,
|
||||
"tools": tools or [],
|
||||
"wrap_tool_call": wrapped,
|
||||
},
|
||||
)(),
|
||||
)
|
||||
|
||||
if func is not None:
|
||||
return decorator(func)
|
||||
|
||||
@@ -76,7 +76,7 @@ class StructuredOutputValidationError(StructuredOutputError):
|
||||
|
||||
def _parse_with_schema(
|
||||
schema: type[SchemaT] | dict[str, Any], schema_kind: SchemaKind, data: dict[str, Any]
|
||||
) -> Any:
|
||||
) -> SchemaT | dict[str, Any]:
|
||||
"""Parse data using for any supported schema type.
|
||||
|
||||
Args:
|
||||
@@ -92,9 +92,10 @@ def _parse_with_schema(
|
||||
ValueError: If parsing fails
|
||||
"""
|
||||
if schema_kind == "json_schema":
|
||||
# Raw JSON schema has no corresponding Python type to instantiate.
|
||||
return data
|
||||
try:
|
||||
adapter: TypeAdapter[SchemaT] = TypeAdapter(schema)
|
||||
adapter = TypeAdapter[SchemaT](schema)
|
||||
return adapter.validate_python(data)
|
||||
except Exception as e:
|
||||
schema_name = getattr(schema, "__name__", str(schema))
|
||||
@@ -344,7 +345,7 @@ class OutputToolBinding(Generic[SchemaT]):
|
||||
),
|
||||
)
|
||||
|
||||
def parse(self, tool_args: dict[str, Any]) -> SchemaT:
|
||||
def parse(self, tool_args: dict[str, Any]) -> SchemaT | dict[str, Any]:
|
||||
"""Parse tool arguments according to the schema.
|
||||
|
||||
Args:
|
||||
@@ -391,7 +392,7 @@ class ProviderStrategyBinding(Generic[SchemaT]):
|
||||
schema_kind=schema_spec.schema_kind,
|
||||
)
|
||||
|
||||
def parse(self, response: AIMessage) -> SchemaT:
|
||||
def parse(self, response: AIMessage) -> SchemaT | dict[str, Any]:
|
||||
"""Parse `AIMessage` content according to the schema.
|
||||
|
||||
Args:
|
||||
|
||||
@@ -1017,17 +1017,37 @@ class _ConfigurableModel(Runnable[LanguageModelInput, Any]):
|
||||
yield x
|
||||
|
||||
# Explicitly added to satisfy downstream linters.
|
||||
# `bind_tools` is implemented by concrete models because tool binding is
|
||||
# provider-specific. A configurable model may not have a concrete model instance
|
||||
# yet, since invocation config can choose it later. Save the `bind_tools` tools
|
||||
# and kwargs now. When `_model` later builds the selected provider model, it calls
|
||||
# `selected_model.bind_tools(tools, **kwargs)` and returns that runnable.
|
||||
# Cast so callers still get the public return type.
|
||||
def bind_tools(
|
||||
self,
|
||||
tools: Sequence[dict[str, Any] | type[BaseModel] | Callable[..., Any] | BaseTool],
|
||||
**kwargs: Any,
|
||||
) -> Runnable[LanguageModelInput, AIMessage]:
|
||||
return self.__getattr__("bind_tools")(tools, **kwargs)
|
||||
return cast(
|
||||
"Runnable[LanguageModelInput, AIMessage]",
|
||||
self.__getattr__("bind_tools")(tools, **kwargs),
|
||||
)
|
||||
|
||||
# Explicitly added to satisfy downstream linters.
|
||||
# `with_structured_output` is implemented by concrete models because structured
|
||||
# output support is provider-specific. A configurable model may not have a
|
||||
# concrete model instance yet, since invocation config can choose it later. Save
|
||||
# the structured-output schema and kwargs now. When `_model` later builds the
|
||||
# selected provider model, it calls
|
||||
# `selected_model.with_structured_output(schema, **kwargs)` and returns that
|
||||
# runnable.
|
||||
# Cast so callers still get the public return type.
|
||||
def with_structured_output(
|
||||
self,
|
||||
schema: dict[str, Any] | type[BaseModel],
|
||||
**kwargs: Any,
|
||||
) -> Runnable[LanguageModelInput, dict[str, Any] | BaseModel]:
|
||||
return self.__getattr__("with_structured_output")(schema, **kwargs)
|
||||
return cast(
|
||||
"Runnable[LanguageModelInput, dict[str, Any] | BaseModel]",
|
||||
self.__getattr__("with_structured_output")(schema, **kwargs),
|
||||
)
|
||||
|
||||
@@ -116,9 +116,6 @@ exclude = [
|
||||
"tests/unit_tests/agents/test_.*\\.py",
|
||||
]
|
||||
|
||||
# TODO: activate for 'strict' checking
|
||||
warn_return_any = false
|
||||
|
||||
[[tool.mypy.overrides]]
|
||||
module = ["pytest_socket.*", "vcr.*"]
|
||||
ignore_missing_imports = true
|
||||
|
||||
Reference in New Issue
Block a user