Compare commits

...

2 Commits

Author SHA1 Message Date
Sydney Runkle
4cc3a15af8 more generics 2025-11-21 11:31:38 -05:00
Sydney Runkle
bacafb2fc6 improved types 2025-11-19 15:42:28 -05:00
2 changed files with 191 additions and 145 deletions

View File

@@ -22,17 +22,20 @@ from langgraph.graph.state import StateGraph
from langgraph.prebuilt.tool_node import ToolCallWithContext, ToolNode
from langgraph.runtime import Runtime # noqa: TC002
from langgraph.types import Command, Send
from langgraph.typing import ContextT # noqa: TC002
from langgraph.typing import ContextT
from typing_extensions import NotRequired, Required, TypedDict
from langchain.agents.middleware.types import (
AgentMiddleware,
AgentState,
AsyncModelCallHandler,
JumpTo,
ModelCallHandler,
ModelRequest,
ModelResponse,
OmitFromSchema,
ResponseT,
StateT,
StateT_co,
_InputAgentState,
_OutputAgentState,
@@ -86,13 +89,13 @@ def _normalize_to_model_response(result: ModelResponse | AIMessage) -> ModelResp
def _chain_model_call_handlers(
handlers: Sequence[
Callable[
[ModelRequest, Callable[[ModelRequest], ModelResponse]],
[ModelRequest[StateT, ContextT], ModelCallHandler[StateT, ContextT]],
ModelResponse | AIMessage,
]
],
) -> (
Callable[
[ModelRequest, Callable[[ModelRequest], ModelResponse]],
[ModelRequest[StateT, ContextT], ModelCallHandler[StateT, ContextT]],
ModelResponse,
]
| None
@@ -140,8 +143,8 @@ def _chain_model_call_handlers(
single_handler = handlers[0]
def normalized_single(
request: ModelRequest,
handler: Callable[[ModelRequest], ModelResponse],
request: ModelRequest[StateT, ContextT],
handler: ModelCallHandler[StateT, ContextT],
) -> ModelResponse:
result = single_handler(request, handler)
return _normalize_to_model_response(result)
@@ -150,25 +153,25 @@ def _chain_model_call_handlers(
def compose_two(
outer: Callable[
[ModelRequest, Callable[[ModelRequest], ModelResponse]],
[ModelRequest[StateT, ContextT], ModelCallHandler[StateT, ContextT]],
ModelResponse | AIMessage,
],
inner: Callable[
[ModelRequest, Callable[[ModelRequest], ModelResponse]],
[ModelRequest[StateT, ContextT], ModelCallHandler[StateT, ContextT]],
ModelResponse | AIMessage,
],
) -> Callable[
[ModelRequest, Callable[[ModelRequest], ModelResponse]],
[ModelRequest[StateT, ContextT], ModelCallHandler[StateT, ContextT]],
ModelResponse,
]:
"""Compose two handlers where outer wraps inner."""
def composed(
request: ModelRequest,
handler: Callable[[ModelRequest], ModelResponse],
request: ModelRequest[StateT, ContextT],
handler: ModelCallHandler[StateT, ContextT],
) -> ModelResponse:
# Create a wrapper that calls inner with the base handler and normalizes
def inner_handler(req: ModelRequest) -> ModelResponse:
def inner_handler(req: ModelRequest[StateT, ContextT]) -> ModelResponse:
inner_result = inner(req, handler)
return _normalize_to_model_response(inner_result)
@@ -185,8 +188,8 @@ def _chain_model_call_handlers(
# Wrap to ensure final return type is exactly ModelResponse
def final_normalized(
request: ModelRequest,
handler: Callable[[ModelRequest], ModelResponse],
request: ModelRequest[StateT, ContextT],
handler: ModelCallHandler[StateT, ContextT],
) -> ModelResponse:
# result here is typed as returning ModelResponse | AIMessage but compose_two normalizes
final_result = result(request, handler)
@@ -198,13 +201,13 @@ def _chain_model_call_handlers(
def _chain_async_model_call_handlers(
handlers: Sequence[
Callable[
[ModelRequest, Callable[[ModelRequest], Awaitable[ModelResponse]]],
[ModelRequest[StateT, ContextT], AsyncModelCallHandler[StateT, ContextT]],
Awaitable[ModelResponse | AIMessage],
]
],
) -> (
Callable[
[ModelRequest, Callable[[ModelRequest], Awaitable[ModelResponse]]],
[ModelRequest[StateT, ContextT], AsyncModelCallHandler[StateT, ContextT]],
Awaitable[ModelResponse],
]
| None
@@ -225,8 +228,8 @@ def _chain_async_model_call_handlers(
single_handler = handlers[0]
async def normalized_single(
request: ModelRequest,
handler: Callable[[ModelRequest], Awaitable[ModelResponse]],
request: ModelRequest[StateT, ContextT],
handler: AsyncModelCallHandler[StateT, ContextT],
) -> ModelResponse:
result = await single_handler(request, handler)
return _normalize_to_model_response(result)
@@ -235,25 +238,25 @@ def _chain_async_model_call_handlers(
def compose_two(
outer: Callable[
[ModelRequest, Callable[[ModelRequest], Awaitable[ModelResponse]]],
[ModelRequest[StateT, ContextT], AsyncModelCallHandler[StateT, ContextT]],
Awaitable[ModelResponse | AIMessage],
],
inner: Callable[
[ModelRequest, Callable[[ModelRequest], Awaitable[ModelResponse]]],
[ModelRequest[StateT, ContextT], AsyncModelCallHandler[StateT, ContextT]],
Awaitable[ModelResponse | AIMessage],
],
) -> Callable[
[ModelRequest, Callable[[ModelRequest], Awaitable[ModelResponse]]],
[ModelRequest[StateT, ContextT], AsyncModelCallHandler[StateT, ContextT]],
Awaitable[ModelResponse],
]:
"""Compose two async handlers where outer wraps inner."""
async def composed(
request: ModelRequest,
handler: Callable[[ModelRequest], Awaitable[ModelResponse]],
request: ModelRequest[StateT, ContextT],
handler: AsyncModelCallHandler[StateT, ContextT],
) -> ModelResponse:
# Create a wrapper that calls inner with the base handler and normalizes
async def inner_handler(req: ModelRequest) -> ModelResponse:
async def inner_handler(req: ModelRequest[StateT, ContextT]) -> ModelResponse:
inner_result = await inner(req, handler)
return _normalize_to_model_response(inner_result)
@@ -270,8 +273,8 @@ def _chain_async_model_call_handlers(
# Wrap to ensure final return type is exactly ModelResponse
async def final_normalized(
request: ModelRequest,
handler: Callable[[ModelRequest], Awaitable[ModelResponse]],
request: ModelRequest[StateT, ContextT],
handler: AsyncModelCallHandler[StateT, ContextT],
) -> ModelResponse:
# result here is typed as returning ModelResponse | AIMessage but compose_two normalizes
final_result = await result(request, handler)
@@ -546,9 +549,9 @@ def create_agent( # noqa: PLR0915
tools: Sequence[BaseTool | Callable | dict[str, Any]] | None = None,
*,
system_prompt: str | None = None,
middleware: Sequence[AgentMiddleware[StateT_co, ContextT]] = (),
response_format: ResponseFormat[ResponseT] | type[ResponseT] | None = None,
state_schema: type[AgentState[ResponseT]] | None = None,
middleware: Sequence[AgentMiddleware[StateT_co, ContextT]] = (),
context_schema: type[ContextT] | None = None,
checkpointer: Checkpointer | None = None,
store: BaseStore | None = None,
@@ -968,7 +971,9 @@ def create_agent( # noqa: PLR0915
return {"messages": [output]}
def _get_bound_model(request: ModelRequest) -> tuple[Runnable, ResponseFormat | None]:
def _get_bound_model(
request: ModelRequest[StateT, ContextT],
) -> tuple[Runnable, ResponseFormat | None]:
"""Get the model with appropriate tool bindings.
Performs auto-detection of strategy if needed based on model capabilities.
@@ -1082,7 +1087,7 @@ def create_agent( # noqa: PLR0915
)
return request.model.bind(**request.model_settings), None
def _execute_model_sync(request: ModelRequest) -> ModelResponse:
def _execute_model_sync(request: ModelRequest[StateT, ContextT]) -> ModelResponse:
"""Execute model and return response.
This is the core model execution logic wrapped by `wrap_model_call` handlers.
@@ -1106,9 +1111,9 @@ def create_agent( # noqa: PLR0915
structured_response=structured_response,
)
def model_node(state: AgentState, runtime: Runtime[ContextT]) -> dict[str, Any]:
def model_node(state: StateT, runtime: Runtime[ContextT]) -> dict[str, Any]:
"""Sync model request handler with sequential middleware processing."""
request = ModelRequest(
request = ModelRequest[StateT, ContextT](
model=model,
tools=default_tools,
system_prompt=system_prompt,
@@ -1133,7 +1138,7 @@ def create_agent( # noqa: PLR0915
return state_updates
async def _execute_model_async(request: ModelRequest) -> ModelResponse:
async def _execute_model_async(request: ModelRequest[StateT, ContextT]) -> ModelResponse:
"""Execute model asynchronously and return response.
This is the core async model execution logic wrapped by `wrap_model_call`
@@ -1159,9 +1164,9 @@ def create_agent( # noqa: PLR0915
structured_response=structured_response,
)
async def amodel_node(state: AgentState, runtime: Runtime[ContextT]) -> dict[str, Any]:
async def amodel_node(state: StateT, runtime: Runtime[ContextT]) -> dict[str, Any]:
"""Async model request handler with sequential middleware processing."""
request = ModelRequest(
request = ModelRequest[StateT, ContextT](
model=model,
tools=default_tools,
system_prompt=system_prompt,

View File

@@ -45,12 +45,16 @@ if TYPE_CHECKING:
__all__ = [
"AgentMiddleware",
"AgentState",
"AsyncModelCallHandler",
"AsyncToolCallHandler",
"ContextT",
"ModelCallHandler",
"ModelRequest",
"ModelResponse",
"OmitFromSchema",
"ResponseT",
"StateT_co",
"ToolCallHandler",
"ToolCallRequest",
"ToolCallWrapper",
"after_agent",
@@ -68,96 +72,6 @@ JumpTo = Literal["tools", "model", "end"]
ResponseT = TypeVar("ResponseT")
class _ModelRequestOverrides(TypedDict, total=False):
"""Possible overrides for `ModelRequest.override()` method."""
model: BaseChatModel
system_prompt: str | None
messages: list[AnyMessage]
tool_choice: Any | None
tools: list[BaseTool | dict]
response_format: ResponseFormat | None
model_settings: dict[str, Any]
@dataclass
class ModelRequest:
"""Model request information for the agent."""
model: BaseChatModel
system_prompt: str | None
messages: list[AnyMessage] # excluding system prompt
tool_choice: Any | None
tools: list[BaseTool | dict]
response_format: ResponseFormat | None
state: AgentState
runtime: Runtime[ContextT] # type: ignore[valid-type]
model_settings: dict[str, Any] = field(default_factory=dict)
def __setattr__(self, name: str, value: Any) -> None:
"""Set an attribute with a deprecation warning.
Direct attribute assignment on `ModelRequest` is deprecated. Use the
`override()` method instead to create a new request with modified attributes.
Args:
name: Attribute name.
value: Attribute value.
"""
import warnings
# Allow setting attributes during __init__ (when object is being constructed)
if not hasattr(self, "__dataclass_fields__") or not hasattr(self, name):
object.__setattr__(self, name, value)
else:
warnings.warn(
f"Direct attribute assignment to ModelRequest.{name} is deprecated. "
f"Use request.override({name}=...) instead to create a new request "
f"with the modified attribute.",
DeprecationWarning,
stacklevel=2,
)
object.__setattr__(self, name, value)
def override(self, **overrides: Unpack[_ModelRequestOverrides]) -> ModelRequest:
"""Replace the request with a new request with the given overrides.
Returns a new `ModelRequest` instance with the specified attributes replaced.
This follows an immutable pattern, leaving the original request unchanged.
Args:
**overrides: Keyword arguments for attributes to override.
Supported keys:
- `model`: `BaseChatModel` instance
- `system_prompt`: Optional system prompt string
- `messages`: `list` of messages
- `tool_choice`: Tool choice configuration
- `tools`: `list` of available tools
- `response_format`: Response format specification
- `model_settings`: Additional model settings
Returns:
New `ModelRequest` instance with specified overrides applied.
Examples:
!!! example "Create a new request with different model"
```python
new_request = request.override(model=different_model)
```
!!! example "Override multiple attributes"
```python
new_request = request.override(system_prompt="New instructions", tool_choice="auto")
```
"""
return replace(self, **overrides)
@dataclass
class ModelResponse:
"""Response from model execution including messages and optional structured output.
@@ -183,6 +97,38 @@ Middleware can return either:
- `AIMessage`: Simplified return for simple use cases
"""
# Type aliases for model call handlers
ModelCallHandler: TypeAlias = "Callable[[ModelRequest[StateT, ContextT]], ModelResponse]"
"""`TypeAlias` for synchronous model call handler callback.
This is the handler function passed to `wrap_model_call` middleware that executes
the model request and returns a `ModelResponse`.
"""
AsyncModelCallHandler: TypeAlias = (
"Callable[[ModelRequest[StateT, ContextT]], Awaitable[ModelResponse]]"
)
"""`TypeAlias` for asynchronous model call handler callback.
This is the handler function passed to `awrap_model_call` middleware that executes
the model request and returns an awaitable `ModelResponse`.
"""
# Type aliases for tool call handlers
ToolCallHandler: TypeAlias = "Callable[[ToolCallRequest], ToolMessage | Command]"
"""`TypeAlias` for synchronous tool call handler callback.
This is the handler function passed to `wrap_tool_call` middleware that executes
the tool call and returns a `ToolMessage` or `Command`.
"""
AsyncToolCallHandler: TypeAlias = "Callable[[ToolCallRequest], Awaitable[ToolMessage | Command]]"
"""`TypeAlias` for asynchronous tool call handler callback.
This is the handler function passed to `awrap_tool_call` middleware that executes
the tool call and returns an awaitable `ToolMessage` or `Command`.
"""
@dataclass
class OmitFromSchema:
@@ -231,6 +177,101 @@ StateT_co = TypeVar("StateT_co", bound=AgentState, default=AgentState, covariant
StateT_contra = TypeVar("StateT_contra", bound=AgentState, contravariant=True)
class _ModelRequestOverrides(TypedDict, total=False):
"""Possible overrides for `ModelRequest.override()` method."""
model: BaseChatModel
system_prompt: str | None
messages: list[AnyMessage]
tool_choice: Any | None
tools: list[BaseTool | dict]
response_format: ResponseFormat | None
model_settings: dict[str, Any]
@dataclass
class ModelRequest(Generic[StateT, ContextT]):
"""Model request information for the agent.
Generic over `ContextT` for better type inference of the runtime context.
"""
model: BaseChatModel
system_prompt: str | None
messages: list[AnyMessage] # excluding system prompt
tool_choice: Any | None
tools: list[BaseTool | dict]
response_format: ResponseFormat | None
state: StateT
runtime: Runtime[ContextT]
model_settings: dict[str, Any] = field(default_factory=dict)
def __setattr__(self, name: str, value: Any) -> None:
"""Set an attribute with a deprecation warning.
Direct attribute assignment on `ModelRequest` is deprecated. Use the
`override()` method instead to create a new request with modified attributes.
Args:
name: Attribute name.
value: Attribute value.
"""
import warnings
# Allow setting attributes during __init__ (when object is being constructed)
if not hasattr(self, "__dataclass_fields__") or not hasattr(self, name):
object.__setattr__(self, name, value)
else:
warnings.warn(
f"Direct attribute assignment to ModelRequest.{name} is deprecated. "
f"Use request.override({name}=...) instead to create a new request "
f"with the modified attribute.",
DeprecationWarning,
stacklevel=2,
)
object.__setattr__(self, name, value)
def override(
self, **overrides: Unpack[_ModelRequestOverrides]
) -> ModelRequest[StateT, ContextT]:
"""Replace the request with a new request with the given overrides.
Returns a new `ModelRequest` instance with the specified attributes replaced.
This follows an immutable pattern, leaving the original request unchanged.
Args:
**overrides: Keyword arguments for attributes to override.
Supported keys:
- `model`: `BaseChatModel` instance
- `system_prompt`: Optional system prompt string
- `messages`: `list` of messages
- `tool_choice`: Tool choice configuration
- `tools`: `list` of available tools
- `response_format`: Response format specification
- `model_settings`: Additional model settings
Returns:
New `ModelRequest` instance with specified overrides applied.
Examples:
!!! example "Create a new request with different model"
```python
new_request = request.override(model=different_model)
```
!!! example "Override multiple attributes"
```python
new_request = request.override(system_prompt="New instructions", tool_choice="auto")
```
"""
return replace(self, **overrides)
class AgentMiddleware(Generic[StateT, ContextT]):
"""Base middleware class for an agent.
@@ -287,8 +328,8 @@ class AgentMiddleware(Generic[StateT, ContextT]):
def wrap_model_call(
self,
request: ModelRequest,
handler: Callable[[ModelRequest], ModelResponse],
request: ModelRequest[StateT, ContextT],
handler: ModelCallHandler[StateT, ContextT],
) -> ModelCallResult:
"""Intercept and control model execution via handler callback.
@@ -382,8 +423,8 @@ class AgentMiddleware(Generic[StateT, ContextT]):
async def awrap_model_call(
self,
request: ModelRequest,
handler: Callable[[ModelRequest], Awaitable[ModelResponse]],
request: ModelRequest[StateT, ContextT],
handler: AsyncModelCallHandler[StateT, ContextT],
) -> ModelCallResult:
"""Intercept and control async model execution via handler callback.
@@ -443,7 +484,7 @@ class AgentMiddleware(Generic[StateT, ContextT]):
def wrap_tool_call(
self,
request: ToolCallRequest,
handler: Callable[[ToolCallRequest], ToolMessage | Command],
handler: ToolCallHandler,
) -> ToolMessage | Command:
"""Intercept tool execution for retries, monitoring, or modification.
@@ -525,7 +566,7 @@ class AgentMiddleware(Generic[StateT, ContextT]):
async def awrap_tool_call(
self,
request: ToolCallRequest,
handler: Callable[[ToolCallRequest], Awaitable[ToolMessage | Command]],
handler: AsyncToolCallHandler,
) -> ToolMessage | Command:
"""Intercept and control async tool execution via handler callback.
@@ -605,7 +646,7 @@ class _CallableWithStateAndRuntime(Protocol[StateT_contra, ContextT]):
class _CallableReturningPromptString(Protocol[StateT_contra, ContextT]): # type: ignore[misc]
"""Callable that returns a prompt string given `ModelRequest` (contains state and runtime)."""
def __call__(self, request: ModelRequest) -> str | Awaitable[str]:
def __call__(self, request: ModelRequest[StateT_contra, ContextT]) -> str | Awaitable[str]:
"""Generate a system prompt string based on the request."""
...
@@ -619,8 +660,8 @@ class _CallableReturningModelResponse(Protocol[StateT_contra, ContextT]): # typ
def __call__(
self,
request: ModelRequest,
handler: Callable[[ModelRequest], ModelResponse],
request: ModelRequest[StateT_contra, ContextT],
handler: ModelCallHandler[StateT_contra, ContextT],
) -> ModelCallResult:
"""Intercept model execution via handler callback."""
...
@@ -636,7 +677,7 @@ class _CallableReturningToolResponse(Protocol):
def __call__(
self,
request: ToolCallRequest,
handler: Callable[[ToolCallRequest], ToolMessage | Command],
handler: ToolCallHandler,
) -> ToolMessage | Command:
"""Intercept tool execution via handler callback."""
...
@@ -1365,8 +1406,8 @@ def dynamic_prompt(
async def async_wrapped(
self: AgentMiddleware[StateT, ContextT], # noqa: ARG001
request: ModelRequest,
handler: Callable[[ModelRequest], Awaitable[ModelResponse]],
request: ModelRequest[StateT, ContextT],
handler: AsyncModelCallHandler[StateT, ContextT],
) -> ModelCallResult:
prompt = await func(request) # type: ignore[misc]
request = request.override(system_prompt=prompt)
@@ -1386,8 +1427,8 @@ def dynamic_prompt(
def wrapped(
self: AgentMiddleware[StateT, ContextT], # noqa: ARG001
request: ModelRequest,
handler: Callable[[ModelRequest], ModelResponse],
request: ModelRequest[StateT, ContextT],
handler: ModelCallHandler[StateT, ContextT],
) -> ModelCallResult:
prompt = cast("str", func(request))
request = request.override(system_prompt=prompt)
@@ -1395,8 +1436,8 @@ def dynamic_prompt(
async def async_wrapped_from_sync(
self: AgentMiddleware[StateT, ContextT], # noqa: ARG001
request: ModelRequest,
handler: Callable[[ModelRequest], Awaitable[ModelResponse]],
request: ModelRequest[StateT, ContextT],
handler: AsyncModelCallHandler[StateT, ContextT],
) -> ModelCallResult:
# Delegate to sync function
prompt = cast("str", func(request))
@@ -1537,8 +1578,8 @@ def wrap_model_call(
async def async_wrapped(
self: AgentMiddleware[StateT, ContextT], # noqa: ARG001
request: ModelRequest,
handler: Callable[[ModelRequest], Awaitable[ModelResponse]],
request: ModelRequest[StateT, ContextT],
handler: AsyncModelCallHandler[StateT, ContextT],
) -> ModelCallResult:
return await func(request, handler) # type: ignore[misc, arg-type]
@@ -1558,8 +1599,8 @@ def wrap_model_call(
def wrapped(
self: AgentMiddleware[StateT, ContextT], # noqa: ARG001
request: ModelRequest,
handler: Callable[[ModelRequest], ModelResponse],
request: ModelRequest[StateT, ContextT],
handler: ModelCallHandler[StateT, ContextT],
) -> ModelCallResult:
return func(request, handler)
@@ -1698,7 +1739,7 @@ def wrap_tool_call(
async def async_wrapped(
self: AgentMiddleware, # noqa: ARG001
request: ToolCallRequest,
handler: Callable[[ToolCallRequest], Awaitable[ToolMessage | Command]],
handler: AsyncToolCallHandler,
) -> ToolMessage | Command:
return await func(request, handler) # type: ignore[arg-type,misc]
@@ -1719,7 +1760,7 @@ def wrap_tool_call(
def wrapped(
self: AgentMiddleware, # noqa: ARG001
request: ToolCallRequest,
handler: Callable[[ToolCallRequest], ToolMessage | Command],
handler: ToolCallHandler,
) -> ToolMessage | Command:
return func(request, handler)