mirror of
https://github.com/hwchase17/langchain.git
synced 2026-04-13 07:52:48 +00:00
Compare commits
2 Commits
langchain-
...
sr/typing-
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
4cc3a15af8 | ||
|
|
bacafb2fc6 |
@@ -22,17 +22,20 @@ from langgraph.graph.state import StateGraph
|
||||
from langgraph.prebuilt.tool_node import ToolCallWithContext, ToolNode
|
||||
from langgraph.runtime import Runtime # noqa: TC002
|
||||
from langgraph.types import Command, Send
|
||||
from langgraph.typing import ContextT # noqa: TC002
|
||||
from langgraph.typing import ContextT
|
||||
from typing_extensions import NotRequired, Required, TypedDict
|
||||
|
||||
from langchain.agents.middleware.types import (
|
||||
AgentMiddleware,
|
||||
AgentState,
|
||||
AsyncModelCallHandler,
|
||||
JumpTo,
|
||||
ModelCallHandler,
|
||||
ModelRequest,
|
||||
ModelResponse,
|
||||
OmitFromSchema,
|
||||
ResponseT,
|
||||
StateT,
|
||||
StateT_co,
|
||||
_InputAgentState,
|
||||
_OutputAgentState,
|
||||
@@ -86,13 +89,13 @@ def _normalize_to_model_response(result: ModelResponse | AIMessage) -> ModelResp
|
||||
def _chain_model_call_handlers(
|
||||
handlers: Sequence[
|
||||
Callable[
|
||||
[ModelRequest, Callable[[ModelRequest], ModelResponse]],
|
||||
[ModelRequest[StateT, ContextT], ModelCallHandler[StateT, ContextT]],
|
||||
ModelResponse | AIMessage,
|
||||
]
|
||||
],
|
||||
) -> (
|
||||
Callable[
|
||||
[ModelRequest, Callable[[ModelRequest], ModelResponse]],
|
||||
[ModelRequest[StateT, ContextT], ModelCallHandler[StateT, ContextT]],
|
||||
ModelResponse,
|
||||
]
|
||||
| None
|
||||
@@ -140,8 +143,8 @@ def _chain_model_call_handlers(
|
||||
single_handler = handlers[0]
|
||||
|
||||
def normalized_single(
|
||||
request: ModelRequest,
|
||||
handler: Callable[[ModelRequest], ModelResponse],
|
||||
request: ModelRequest[StateT, ContextT],
|
||||
handler: ModelCallHandler[StateT, ContextT],
|
||||
) -> ModelResponse:
|
||||
result = single_handler(request, handler)
|
||||
return _normalize_to_model_response(result)
|
||||
@@ -150,25 +153,25 @@ def _chain_model_call_handlers(
|
||||
|
||||
def compose_two(
|
||||
outer: Callable[
|
||||
[ModelRequest, Callable[[ModelRequest], ModelResponse]],
|
||||
[ModelRequest[StateT, ContextT], ModelCallHandler[StateT, ContextT]],
|
||||
ModelResponse | AIMessage,
|
||||
],
|
||||
inner: Callable[
|
||||
[ModelRequest, Callable[[ModelRequest], ModelResponse]],
|
||||
[ModelRequest[StateT, ContextT], ModelCallHandler[StateT, ContextT]],
|
||||
ModelResponse | AIMessage,
|
||||
],
|
||||
) -> Callable[
|
||||
[ModelRequest, Callable[[ModelRequest], ModelResponse]],
|
||||
[ModelRequest[StateT, ContextT], ModelCallHandler[StateT, ContextT]],
|
||||
ModelResponse,
|
||||
]:
|
||||
"""Compose two handlers where outer wraps inner."""
|
||||
|
||||
def composed(
|
||||
request: ModelRequest,
|
||||
handler: Callable[[ModelRequest], ModelResponse],
|
||||
request: ModelRequest[StateT, ContextT],
|
||||
handler: ModelCallHandler[StateT, ContextT],
|
||||
) -> ModelResponse:
|
||||
# Create a wrapper that calls inner with the base handler and normalizes
|
||||
def inner_handler(req: ModelRequest) -> ModelResponse:
|
||||
def inner_handler(req: ModelRequest[StateT, ContextT]) -> ModelResponse:
|
||||
inner_result = inner(req, handler)
|
||||
return _normalize_to_model_response(inner_result)
|
||||
|
||||
@@ -185,8 +188,8 @@ def _chain_model_call_handlers(
|
||||
|
||||
# Wrap to ensure final return type is exactly ModelResponse
|
||||
def final_normalized(
|
||||
request: ModelRequest,
|
||||
handler: Callable[[ModelRequest], ModelResponse],
|
||||
request: ModelRequest[StateT, ContextT],
|
||||
handler: ModelCallHandler[StateT, ContextT],
|
||||
) -> ModelResponse:
|
||||
# result here is typed as returning ModelResponse | AIMessage but compose_two normalizes
|
||||
final_result = result(request, handler)
|
||||
@@ -198,13 +201,13 @@ def _chain_model_call_handlers(
|
||||
def _chain_async_model_call_handlers(
|
||||
handlers: Sequence[
|
||||
Callable[
|
||||
[ModelRequest, Callable[[ModelRequest], Awaitable[ModelResponse]]],
|
||||
[ModelRequest[StateT, ContextT], AsyncModelCallHandler[StateT, ContextT]],
|
||||
Awaitable[ModelResponse | AIMessage],
|
||||
]
|
||||
],
|
||||
) -> (
|
||||
Callable[
|
||||
[ModelRequest, Callable[[ModelRequest], Awaitable[ModelResponse]]],
|
||||
[ModelRequest[StateT, ContextT], AsyncModelCallHandler[StateT, ContextT]],
|
||||
Awaitable[ModelResponse],
|
||||
]
|
||||
| None
|
||||
@@ -225,8 +228,8 @@ def _chain_async_model_call_handlers(
|
||||
single_handler = handlers[0]
|
||||
|
||||
async def normalized_single(
|
||||
request: ModelRequest,
|
||||
handler: Callable[[ModelRequest], Awaitable[ModelResponse]],
|
||||
request: ModelRequest[StateT, ContextT],
|
||||
handler: AsyncModelCallHandler[StateT, ContextT],
|
||||
) -> ModelResponse:
|
||||
result = await single_handler(request, handler)
|
||||
return _normalize_to_model_response(result)
|
||||
@@ -235,25 +238,25 @@ def _chain_async_model_call_handlers(
|
||||
|
||||
def compose_two(
|
||||
outer: Callable[
|
||||
[ModelRequest, Callable[[ModelRequest], Awaitable[ModelResponse]]],
|
||||
[ModelRequest[StateT, ContextT], AsyncModelCallHandler[StateT, ContextT]],
|
||||
Awaitable[ModelResponse | AIMessage],
|
||||
],
|
||||
inner: Callable[
|
||||
[ModelRequest, Callable[[ModelRequest], Awaitable[ModelResponse]]],
|
||||
[ModelRequest[StateT, ContextT], AsyncModelCallHandler[StateT, ContextT]],
|
||||
Awaitable[ModelResponse | AIMessage],
|
||||
],
|
||||
) -> Callable[
|
||||
[ModelRequest, Callable[[ModelRequest], Awaitable[ModelResponse]]],
|
||||
[ModelRequest[StateT, ContextT], AsyncModelCallHandler[StateT, ContextT]],
|
||||
Awaitable[ModelResponse],
|
||||
]:
|
||||
"""Compose two async handlers where outer wraps inner."""
|
||||
|
||||
async def composed(
|
||||
request: ModelRequest,
|
||||
handler: Callable[[ModelRequest], Awaitable[ModelResponse]],
|
||||
request: ModelRequest[StateT, ContextT],
|
||||
handler: AsyncModelCallHandler[StateT, ContextT],
|
||||
) -> ModelResponse:
|
||||
# Create a wrapper that calls inner with the base handler and normalizes
|
||||
async def inner_handler(req: ModelRequest) -> ModelResponse:
|
||||
async def inner_handler(req: ModelRequest[StateT, ContextT]) -> ModelResponse:
|
||||
inner_result = await inner(req, handler)
|
||||
return _normalize_to_model_response(inner_result)
|
||||
|
||||
@@ -270,8 +273,8 @@ def _chain_async_model_call_handlers(
|
||||
|
||||
# Wrap to ensure final return type is exactly ModelResponse
|
||||
async def final_normalized(
|
||||
request: ModelRequest,
|
||||
handler: Callable[[ModelRequest], Awaitable[ModelResponse]],
|
||||
request: ModelRequest[StateT, ContextT],
|
||||
handler: AsyncModelCallHandler[StateT, ContextT],
|
||||
) -> ModelResponse:
|
||||
# result here is typed as returning ModelResponse | AIMessage but compose_two normalizes
|
||||
final_result = await result(request, handler)
|
||||
@@ -546,9 +549,9 @@ def create_agent( # noqa: PLR0915
|
||||
tools: Sequence[BaseTool | Callable | dict[str, Any]] | None = None,
|
||||
*,
|
||||
system_prompt: str | None = None,
|
||||
middleware: Sequence[AgentMiddleware[StateT_co, ContextT]] = (),
|
||||
response_format: ResponseFormat[ResponseT] | type[ResponseT] | None = None,
|
||||
state_schema: type[AgentState[ResponseT]] | None = None,
|
||||
middleware: Sequence[AgentMiddleware[StateT_co, ContextT]] = (),
|
||||
context_schema: type[ContextT] | None = None,
|
||||
checkpointer: Checkpointer | None = None,
|
||||
store: BaseStore | None = None,
|
||||
@@ -968,7 +971,9 @@ def create_agent( # noqa: PLR0915
|
||||
|
||||
return {"messages": [output]}
|
||||
|
||||
def _get_bound_model(request: ModelRequest) -> tuple[Runnable, ResponseFormat | None]:
|
||||
def _get_bound_model(
|
||||
request: ModelRequest[StateT, ContextT],
|
||||
) -> tuple[Runnable, ResponseFormat | None]:
|
||||
"""Get the model with appropriate tool bindings.
|
||||
|
||||
Performs auto-detection of strategy if needed based on model capabilities.
|
||||
@@ -1082,7 +1087,7 @@ def create_agent( # noqa: PLR0915
|
||||
)
|
||||
return request.model.bind(**request.model_settings), None
|
||||
|
||||
def _execute_model_sync(request: ModelRequest) -> ModelResponse:
|
||||
def _execute_model_sync(request: ModelRequest[StateT, ContextT]) -> ModelResponse:
|
||||
"""Execute model and return response.
|
||||
|
||||
This is the core model execution logic wrapped by `wrap_model_call` handlers.
|
||||
@@ -1106,9 +1111,9 @@ def create_agent( # noqa: PLR0915
|
||||
structured_response=structured_response,
|
||||
)
|
||||
|
||||
def model_node(state: AgentState, runtime: Runtime[ContextT]) -> dict[str, Any]:
|
||||
def model_node(state: StateT, runtime: Runtime[ContextT]) -> dict[str, Any]:
|
||||
"""Sync model request handler with sequential middleware processing."""
|
||||
request = ModelRequest(
|
||||
request = ModelRequest[StateT, ContextT](
|
||||
model=model,
|
||||
tools=default_tools,
|
||||
system_prompt=system_prompt,
|
||||
@@ -1133,7 +1138,7 @@ def create_agent( # noqa: PLR0915
|
||||
|
||||
return state_updates
|
||||
|
||||
async def _execute_model_async(request: ModelRequest) -> ModelResponse:
|
||||
async def _execute_model_async(request: ModelRequest[StateT, ContextT]) -> ModelResponse:
|
||||
"""Execute model asynchronously and return response.
|
||||
|
||||
This is the core async model execution logic wrapped by `wrap_model_call`
|
||||
@@ -1159,9 +1164,9 @@ def create_agent( # noqa: PLR0915
|
||||
structured_response=structured_response,
|
||||
)
|
||||
|
||||
async def amodel_node(state: AgentState, runtime: Runtime[ContextT]) -> dict[str, Any]:
|
||||
async def amodel_node(state: StateT, runtime: Runtime[ContextT]) -> dict[str, Any]:
|
||||
"""Async model request handler with sequential middleware processing."""
|
||||
request = ModelRequest(
|
||||
request = ModelRequest[StateT, ContextT](
|
||||
model=model,
|
||||
tools=default_tools,
|
||||
system_prompt=system_prompt,
|
||||
|
||||
@@ -45,12 +45,16 @@ if TYPE_CHECKING:
|
||||
__all__ = [
|
||||
"AgentMiddleware",
|
||||
"AgentState",
|
||||
"AsyncModelCallHandler",
|
||||
"AsyncToolCallHandler",
|
||||
"ContextT",
|
||||
"ModelCallHandler",
|
||||
"ModelRequest",
|
||||
"ModelResponse",
|
||||
"OmitFromSchema",
|
||||
"ResponseT",
|
||||
"StateT_co",
|
||||
"ToolCallHandler",
|
||||
"ToolCallRequest",
|
||||
"ToolCallWrapper",
|
||||
"after_agent",
|
||||
@@ -68,96 +72,6 @@ JumpTo = Literal["tools", "model", "end"]
|
||||
ResponseT = TypeVar("ResponseT")
|
||||
|
||||
|
||||
class _ModelRequestOverrides(TypedDict, total=False):
|
||||
"""Possible overrides for `ModelRequest.override()` method."""
|
||||
|
||||
model: BaseChatModel
|
||||
system_prompt: str | None
|
||||
messages: list[AnyMessage]
|
||||
tool_choice: Any | None
|
||||
tools: list[BaseTool | dict]
|
||||
response_format: ResponseFormat | None
|
||||
model_settings: dict[str, Any]
|
||||
|
||||
|
||||
@dataclass
|
||||
class ModelRequest:
|
||||
"""Model request information for the agent."""
|
||||
|
||||
model: BaseChatModel
|
||||
system_prompt: str | None
|
||||
messages: list[AnyMessage] # excluding system prompt
|
||||
tool_choice: Any | None
|
||||
tools: list[BaseTool | dict]
|
||||
response_format: ResponseFormat | None
|
||||
state: AgentState
|
||||
runtime: Runtime[ContextT] # type: ignore[valid-type]
|
||||
model_settings: dict[str, Any] = field(default_factory=dict)
|
||||
|
||||
def __setattr__(self, name: str, value: Any) -> None:
|
||||
"""Set an attribute with a deprecation warning.
|
||||
|
||||
Direct attribute assignment on `ModelRequest` is deprecated. Use the
|
||||
`override()` method instead to create a new request with modified attributes.
|
||||
|
||||
Args:
|
||||
name: Attribute name.
|
||||
value: Attribute value.
|
||||
"""
|
||||
import warnings
|
||||
|
||||
# Allow setting attributes during __init__ (when object is being constructed)
|
||||
if not hasattr(self, "__dataclass_fields__") or not hasattr(self, name):
|
||||
object.__setattr__(self, name, value)
|
||||
else:
|
||||
warnings.warn(
|
||||
f"Direct attribute assignment to ModelRequest.{name} is deprecated. "
|
||||
f"Use request.override({name}=...) instead to create a new request "
|
||||
f"with the modified attribute.",
|
||||
DeprecationWarning,
|
||||
stacklevel=2,
|
||||
)
|
||||
object.__setattr__(self, name, value)
|
||||
|
||||
def override(self, **overrides: Unpack[_ModelRequestOverrides]) -> ModelRequest:
|
||||
"""Replace the request with a new request with the given overrides.
|
||||
|
||||
Returns a new `ModelRequest` instance with the specified attributes replaced.
|
||||
|
||||
This follows an immutable pattern, leaving the original request unchanged.
|
||||
|
||||
Args:
|
||||
**overrides: Keyword arguments for attributes to override.
|
||||
|
||||
Supported keys:
|
||||
|
||||
- `model`: `BaseChatModel` instance
|
||||
- `system_prompt`: Optional system prompt string
|
||||
- `messages`: `list` of messages
|
||||
- `tool_choice`: Tool choice configuration
|
||||
- `tools`: `list` of available tools
|
||||
- `response_format`: Response format specification
|
||||
- `model_settings`: Additional model settings
|
||||
|
||||
Returns:
|
||||
New `ModelRequest` instance with specified overrides applied.
|
||||
|
||||
Examples:
|
||||
!!! example "Create a new request with different model"
|
||||
|
||||
```python
|
||||
new_request = request.override(model=different_model)
|
||||
```
|
||||
|
||||
!!! example "Override multiple attributes"
|
||||
|
||||
```python
|
||||
new_request = request.override(system_prompt="New instructions", tool_choice="auto")
|
||||
```
|
||||
"""
|
||||
return replace(self, **overrides)
|
||||
|
||||
|
||||
@dataclass
|
||||
class ModelResponse:
|
||||
"""Response from model execution including messages and optional structured output.
|
||||
@@ -183,6 +97,38 @@ Middleware can return either:
|
||||
- `AIMessage`: Simplified return for simple use cases
|
||||
"""
|
||||
|
||||
# Type aliases for model call handlers
|
||||
ModelCallHandler: TypeAlias = "Callable[[ModelRequest[StateT, ContextT]], ModelResponse]"
|
||||
"""`TypeAlias` for synchronous model call handler callback.
|
||||
|
||||
This is the handler function passed to `wrap_model_call` middleware that executes
|
||||
the model request and returns a `ModelResponse`.
|
||||
"""
|
||||
|
||||
AsyncModelCallHandler: TypeAlias = (
|
||||
"Callable[[ModelRequest[StateT, ContextT]], Awaitable[ModelResponse]]"
|
||||
)
|
||||
"""`TypeAlias` for asynchronous model call handler callback.
|
||||
|
||||
This is the handler function passed to `awrap_model_call` middleware that executes
|
||||
the model request and returns an awaitable `ModelResponse`.
|
||||
"""
|
||||
|
||||
# Type aliases for tool call handlers
|
||||
ToolCallHandler: TypeAlias = "Callable[[ToolCallRequest], ToolMessage | Command]"
|
||||
"""`TypeAlias` for synchronous tool call handler callback.
|
||||
|
||||
This is the handler function passed to `wrap_tool_call` middleware that executes
|
||||
the tool call and returns a `ToolMessage` or `Command`.
|
||||
"""
|
||||
|
||||
AsyncToolCallHandler: TypeAlias = "Callable[[ToolCallRequest], Awaitable[ToolMessage | Command]]"
|
||||
"""`TypeAlias` for asynchronous tool call handler callback.
|
||||
|
||||
This is the handler function passed to `awrap_tool_call` middleware that executes
|
||||
the tool call and returns an awaitable `ToolMessage` or `Command`.
|
||||
"""
|
||||
|
||||
|
||||
@dataclass
|
||||
class OmitFromSchema:
|
||||
@@ -231,6 +177,101 @@ StateT_co = TypeVar("StateT_co", bound=AgentState, default=AgentState, covariant
|
||||
StateT_contra = TypeVar("StateT_contra", bound=AgentState, contravariant=True)
|
||||
|
||||
|
||||
class _ModelRequestOverrides(TypedDict, total=False):
|
||||
"""Possible overrides for `ModelRequest.override()` method."""
|
||||
|
||||
model: BaseChatModel
|
||||
system_prompt: str | None
|
||||
messages: list[AnyMessage]
|
||||
tool_choice: Any | None
|
||||
tools: list[BaseTool | dict]
|
||||
response_format: ResponseFormat | None
|
||||
model_settings: dict[str, Any]
|
||||
|
||||
|
||||
@dataclass
|
||||
class ModelRequest(Generic[StateT, ContextT]):
|
||||
"""Model request information for the agent.
|
||||
|
||||
Generic over `ContextT` for better type inference of the runtime context.
|
||||
"""
|
||||
|
||||
model: BaseChatModel
|
||||
system_prompt: str | None
|
||||
messages: list[AnyMessage] # excluding system prompt
|
||||
tool_choice: Any | None
|
||||
tools: list[BaseTool | dict]
|
||||
response_format: ResponseFormat | None
|
||||
state: StateT
|
||||
runtime: Runtime[ContextT]
|
||||
model_settings: dict[str, Any] = field(default_factory=dict)
|
||||
|
||||
def __setattr__(self, name: str, value: Any) -> None:
|
||||
"""Set an attribute with a deprecation warning.
|
||||
|
||||
Direct attribute assignment on `ModelRequest` is deprecated. Use the
|
||||
`override()` method instead to create a new request with modified attributes.
|
||||
|
||||
Args:
|
||||
name: Attribute name.
|
||||
value: Attribute value.
|
||||
"""
|
||||
import warnings
|
||||
|
||||
# Allow setting attributes during __init__ (when object is being constructed)
|
||||
if not hasattr(self, "__dataclass_fields__") or not hasattr(self, name):
|
||||
object.__setattr__(self, name, value)
|
||||
else:
|
||||
warnings.warn(
|
||||
f"Direct attribute assignment to ModelRequest.{name} is deprecated. "
|
||||
f"Use request.override({name}=...) instead to create a new request "
|
||||
f"with the modified attribute.",
|
||||
DeprecationWarning,
|
||||
stacklevel=2,
|
||||
)
|
||||
object.__setattr__(self, name, value)
|
||||
|
||||
def override(
|
||||
self, **overrides: Unpack[_ModelRequestOverrides]
|
||||
) -> ModelRequest[StateT, ContextT]:
|
||||
"""Replace the request with a new request with the given overrides.
|
||||
|
||||
Returns a new `ModelRequest` instance with the specified attributes replaced.
|
||||
|
||||
This follows an immutable pattern, leaving the original request unchanged.
|
||||
|
||||
Args:
|
||||
**overrides: Keyword arguments for attributes to override.
|
||||
|
||||
Supported keys:
|
||||
|
||||
- `model`: `BaseChatModel` instance
|
||||
- `system_prompt`: Optional system prompt string
|
||||
- `messages`: `list` of messages
|
||||
- `tool_choice`: Tool choice configuration
|
||||
- `tools`: `list` of available tools
|
||||
- `response_format`: Response format specification
|
||||
- `model_settings`: Additional model settings
|
||||
|
||||
Returns:
|
||||
New `ModelRequest` instance with specified overrides applied.
|
||||
|
||||
Examples:
|
||||
!!! example "Create a new request with different model"
|
||||
|
||||
```python
|
||||
new_request = request.override(model=different_model)
|
||||
```
|
||||
|
||||
!!! example "Override multiple attributes"
|
||||
|
||||
```python
|
||||
new_request = request.override(system_prompt="New instructions", tool_choice="auto")
|
||||
```
|
||||
"""
|
||||
return replace(self, **overrides)
|
||||
|
||||
|
||||
class AgentMiddleware(Generic[StateT, ContextT]):
|
||||
"""Base middleware class for an agent.
|
||||
|
||||
@@ -287,8 +328,8 @@ class AgentMiddleware(Generic[StateT, ContextT]):
|
||||
|
||||
def wrap_model_call(
|
||||
self,
|
||||
request: ModelRequest,
|
||||
handler: Callable[[ModelRequest], ModelResponse],
|
||||
request: ModelRequest[StateT, ContextT],
|
||||
handler: ModelCallHandler[StateT, ContextT],
|
||||
) -> ModelCallResult:
|
||||
"""Intercept and control model execution via handler callback.
|
||||
|
||||
@@ -382,8 +423,8 @@ class AgentMiddleware(Generic[StateT, ContextT]):
|
||||
|
||||
async def awrap_model_call(
|
||||
self,
|
||||
request: ModelRequest,
|
||||
handler: Callable[[ModelRequest], Awaitable[ModelResponse]],
|
||||
request: ModelRequest[StateT, ContextT],
|
||||
handler: AsyncModelCallHandler[StateT, ContextT],
|
||||
) -> ModelCallResult:
|
||||
"""Intercept and control async model execution via handler callback.
|
||||
|
||||
@@ -443,7 +484,7 @@ class AgentMiddleware(Generic[StateT, ContextT]):
|
||||
def wrap_tool_call(
|
||||
self,
|
||||
request: ToolCallRequest,
|
||||
handler: Callable[[ToolCallRequest], ToolMessage | Command],
|
||||
handler: ToolCallHandler,
|
||||
) -> ToolMessage | Command:
|
||||
"""Intercept tool execution for retries, monitoring, or modification.
|
||||
|
||||
@@ -525,7 +566,7 @@ class AgentMiddleware(Generic[StateT, ContextT]):
|
||||
async def awrap_tool_call(
|
||||
self,
|
||||
request: ToolCallRequest,
|
||||
handler: Callable[[ToolCallRequest], Awaitable[ToolMessage | Command]],
|
||||
handler: AsyncToolCallHandler,
|
||||
) -> ToolMessage | Command:
|
||||
"""Intercept and control async tool execution via handler callback.
|
||||
|
||||
@@ -605,7 +646,7 @@ class _CallableWithStateAndRuntime(Protocol[StateT_contra, ContextT]):
|
||||
class _CallableReturningPromptString(Protocol[StateT_contra, ContextT]): # type: ignore[misc]
|
||||
"""Callable that returns a prompt string given `ModelRequest` (contains state and runtime)."""
|
||||
|
||||
def __call__(self, request: ModelRequest) -> str | Awaitable[str]:
|
||||
def __call__(self, request: ModelRequest[StateT_contra, ContextT]) -> str | Awaitable[str]:
|
||||
"""Generate a system prompt string based on the request."""
|
||||
...
|
||||
|
||||
@@ -619,8 +660,8 @@ class _CallableReturningModelResponse(Protocol[StateT_contra, ContextT]): # typ
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
request: ModelRequest,
|
||||
handler: Callable[[ModelRequest], ModelResponse],
|
||||
request: ModelRequest[StateT_contra, ContextT],
|
||||
handler: ModelCallHandler[StateT_contra, ContextT],
|
||||
) -> ModelCallResult:
|
||||
"""Intercept model execution via handler callback."""
|
||||
...
|
||||
@@ -636,7 +677,7 @@ class _CallableReturningToolResponse(Protocol):
|
||||
def __call__(
|
||||
self,
|
||||
request: ToolCallRequest,
|
||||
handler: Callable[[ToolCallRequest], ToolMessage | Command],
|
||||
handler: ToolCallHandler,
|
||||
) -> ToolMessage | Command:
|
||||
"""Intercept tool execution via handler callback."""
|
||||
...
|
||||
@@ -1365,8 +1406,8 @@ def dynamic_prompt(
|
||||
|
||||
async def async_wrapped(
|
||||
self: AgentMiddleware[StateT, ContextT], # noqa: ARG001
|
||||
request: ModelRequest,
|
||||
handler: Callable[[ModelRequest], Awaitable[ModelResponse]],
|
||||
request: ModelRequest[StateT, ContextT],
|
||||
handler: AsyncModelCallHandler[StateT, ContextT],
|
||||
) -> ModelCallResult:
|
||||
prompt = await func(request) # type: ignore[misc]
|
||||
request = request.override(system_prompt=prompt)
|
||||
@@ -1386,8 +1427,8 @@ def dynamic_prompt(
|
||||
|
||||
def wrapped(
|
||||
self: AgentMiddleware[StateT, ContextT], # noqa: ARG001
|
||||
request: ModelRequest,
|
||||
handler: Callable[[ModelRequest], ModelResponse],
|
||||
request: ModelRequest[StateT, ContextT],
|
||||
handler: ModelCallHandler[StateT, ContextT],
|
||||
) -> ModelCallResult:
|
||||
prompt = cast("str", func(request))
|
||||
request = request.override(system_prompt=prompt)
|
||||
@@ -1395,8 +1436,8 @@ def dynamic_prompt(
|
||||
|
||||
async def async_wrapped_from_sync(
|
||||
self: AgentMiddleware[StateT, ContextT], # noqa: ARG001
|
||||
request: ModelRequest,
|
||||
handler: Callable[[ModelRequest], Awaitable[ModelResponse]],
|
||||
request: ModelRequest[StateT, ContextT],
|
||||
handler: AsyncModelCallHandler[StateT, ContextT],
|
||||
) -> ModelCallResult:
|
||||
# Delegate to sync function
|
||||
prompt = cast("str", func(request))
|
||||
@@ -1537,8 +1578,8 @@ def wrap_model_call(
|
||||
|
||||
async def async_wrapped(
|
||||
self: AgentMiddleware[StateT, ContextT], # noqa: ARG001
|
||||
request: ModelRequest,
|
||||
handler: Callable[[ModelRequest], Awaitable[ModelResponse]],
|
||||
request: ModelRequest[StateT, ContextT],
|
||||
handler: AsyncModelCallHandler[StateT, ContextT],
|
||||
) -> ModelCallResult:
|
||||
return await func(request, handler) # type: ignore[misc, arg-type]
|
||||
|
||||
@@ -1558,8 +1599,8 @@ def wrap_model_call(
|
||||
|
||||
def wrapped(
|
||||
self: AgentMiddleware[StateT, ContextT], # noqa: ARG001
|
||||
request: ModelRequest,
|
||||
handler: Callable[[ModelRequest], ModelResponse],
|
||||
request: ModelRequest[StateT, ContextT],
|
||||
handler: ModelCallHandler[StateT, ContextT],
|
||||
) -> ModelCallResult:
|
||||
return func(request, handler)
|
||||
|
||||
@@ -1698,7 +1739,7 @@ def wrap_tool_call(
|
||||
async def async_wrapped(
|
||||
self: AgentMiddleware, # noqa: ARG001
|
||||
request: ToolCallRequest,
|
||||
handler: Callable[[ToolCallRequest], Awaitable[ToolMessage | Command]],
|
||||
handler: AsyncToolCallHandler,
|
||||
) -> ToolMessage | Command:
|
||||
return await func(request, handler) # type: ignore[arg-type,misc]
|
||||
|
||||
@@ -1719,7 +1760,7 @@ def wrap_tool_call(
|
||||
def wrapped(
|
||||
self: AgentMiddleware, # noqa: ARG001
|
||||
request: ToolCallRequest,
|
||||
handler: Callable[[ToolCallRequest], ToolMessage | Command],
|
||||
handler: ToolCallHandler,
|
||||
) -> ToolMessage | Command:
|
||||
return func(request, handler)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user