mirror of https://github.com/hwchase17/langchain.git
synced 2026-02-09 02:33:34 +00:00

Compare commits: mdrxy/tool...update_mod (2 commits)

| Author | SHA1 | Date |
|---|---|---|
|  | c21b43fb4e |  |
|  | 05eed19605 |  |
@@ -32,7 +32,9 @@ from langchain.agents.middleware.types import (
     AgentMiddleware,
     AgentState,
     JumpTo,
+    ModelCall,
     ModelRequest,
+    ModelResponse,
     OmitFromSchema,
     PublicAgentState,
 )
@@ -87,14 +89,14 @@ class _InternalModelResponse:
 def _chain_model_call_handlers(
     handlers: Sequence[
         Callable[
-            [ModelRequest, Callable[[ModelRequest], AIMessage]],
-            AIMessage,
+            [ModelRequest, Callable[[ModelCall], ModelResponse]],
+            ModelResponse,
         ]
     ],
 ) -> (
     Callable[
-        [ModelRequest, Callable[[ModelRequest], AIMessage]],
-        AIMessage,
+        [ModelRequest, Callable[[ModelCall], ModelResponse]],
+        ModelResponse,
     ]
     | None
 ):
@@ -141,26 +143,26 @@ def _chain_model_call_handlers(

     def compose_two(
         outer: Callable[
-            [ModelRequest, Callable[[ModelRequest], AIMessage]],
-            AIMessage,
+            [ModelRequest, Callable[[ModelCall], ModelResponse]],
+            ModelResponse,
         ],
         inner: Callable[
-            [ModelRequest, Callable[[ModelRequest], AIMessage]],
-            AIMessage,
+            [ModelRequest, Callable[[ModelCall], ModelResponse]],
+            ModelResponse,
         ],
     ) -> Callable[
-        [ModelRequest, Callable[[ModelRequest], AIMessage]],
-        AIMessage,
+        [ModelRequest, Callable[[ModelCall], ModelResponse]],
+        ModelResponse,
     ]:
         """Compose two handlers where outer wraps inner."""

         def composed(
             request: ModelRequest,
-            handler: Callable[[ModelRequest], AIMessage],
-        ) -> AIMessage:
+            handler: Callable[[ModelCall], ModelResponse],
+        ) -> ModelResponse:
             # Create a wrapper that calls inner with the base handler
-            def inner_handler(req: ModelRequest) -> AIMessage:
-                return inner(req, handler)
+            def inner_handler(_model_call: ModelCall) -> ModelResponse:
+                return inner(request, handler)

             # Call outer with the wrapped inner as its handler
             return outer(request, inner_handler)
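This composition is the heart of the middleware chain: `outer` receives a handler that, when called, runs `inner` with the original base handler. A minimal sketch of the same shape with plain functions (no LangChain types; the `Handler` alias and string payloads are illustrative stand-ins only):

```python
from typing import Callable

Handler = Callable[[str], str]  # stand-in for Callable[[ModelCall], ModelResponse]
Middleware = Callable[[str, Handler], str]


def compose_two(outer: Middleware, inner: Middleware) -> Middleware:
    """Compose two middleware where outer wraps inner, mirroring the diff."""

    def composed(request: str, handler: Handler) -> str:
        # The wrapper ignores its argument and closes over the original
        # request, just like inner_handler(_model_call) in the diff above.
        def inner_handler(_ignored: str) -> str:
            return inner(request, handler)

        return outer(request, inner_handler)

    return composed


outer = lambda req, h: f"outer<{h(req)}>"
inner = lambda req, h: f"inner<{h(req)}>"
base = lambda req: f"model({req})"

print(compose_two(outer, inner)("hi", base))  # -> outer<inner<model(hi)>>
```

Worth noting in the new signature: `inner_handler` deliberately ignores the `ModelCall` it is passed and re-reads `request`, so mutations a middleware makes to `request.model_call` are what the next layer actually sees.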
@@ -178,14 +180,14 @@ def _chain_model_call_handlers(
 def _chain_async_model_call_handlers(
     handlers: Sequence[
         Callable[
-            [ModelRequest, Callable[[ModelRequest], Awaitable[AIMessage]]],
-            Awaitable[AIMessage],
+            [ModelRequest, Callable[[ModelCall], Awaitable[ModelResponse]]],
+            Awaitable[ModelResponse],
         ]
     ],
 ) -> (
     Callable[
-        [ModelRequest, Callable[[ModelRequest], Awaitable[AIMessage]]],
-        Awaitable[AIMessage],
+        [ModelRequest, Callable[[ModelCall], Awaitable[ModelResponse]]],
+        Awaitable[ModelResponse],
     ]
     | None
 ):
@@ -205,26 +207,26 @@ def _chain_async_model_call_handlers(

     def compose_two(
         outer: Callable[
-            [ModelRequest, Callable[[ModelRequest], Awaitable[AIMessage]]],
-            Awaitable[AIMessage],
+            [ModelRequest, Callable[[ModelCall], Awaitable[ModelResponse]]],
+            Awaitable[ModelResponse],
         ],
         inner: Callable[
-            [ModelRequest, Callable[[ModelRequest], Awaitable[AIMessage]]],
-            Awaitable[AIMessage],
+            [ModelRequest, Callable[[ModelCall], Awaitable[ModelResponse]]],
+            Awaitable[ModelResponse],
         ],
     ) -> Callable[
-        [ModelRequest, Callable[[ModelRequest], Awaitable[AIMessage]]],
-        Awaitable[AIMessage],
+        [ModelRequest, Callable[[ModelCall], Awaitable[ModelResponse]]],
+        Awaitable[ModelResponse],
     ]:
         """Compose two async handlers where outer wraps inner."""

         async def composed(
             request: ModelRequest,
-            handler: Callable[[ModelRequest], Awaitable[AIMessage]],
-        ) -> AIMessage:
+            handler: Callable[[ModelCall], Awaitable[ModelResponse]],
+        ) -> ModelResponse:
             # Create a wrapper that calls inner with the base handler
-            async def inner_handler(req: ModelRequest) -> AIMessage:
-                return await inner(req, handler)
+            async def inner_handler(_model_call: ModelCall) -> ModelResponse:
+                return await inner(request, handler)

             # Call outer with the wrapped inner as its handler
             return await outer(request, inner_handler)
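The diff does not show how the chaining functions fold more than two handlers, but the docstrings elsewhere in this change state that the first middleware in the list becomes the outermost layer. Under that assumption, a left fold over `compose_two` produces exactly that ordering (a sketch, reusing `compose_two` from the note above):

```python
from functools import reduce


def chain(handlers):
    # Assumption: an empty input yields None, matching the `| None`
    # return type in the signatures above.
    if not handlers:
        return None
    # reduce(compose_two, [a, b, c]) == compose_two(compose_two(a, b), c),
    # which calls a -> b -> c -> base handler: first in list is outermost.
    return reduce(compose_two, handlers)
```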
@@ -744,13 +746,13 @@ def create_agent(  # noqa: PLR0915
         return {"messages": [output]}

-    def _get_bound_model(request: ModelRequest) -> tuple[Runnable, ResponseFormat | None]:
+    def _get_bound_model(model_call: ModelCall) -> tuple[Runnable, ResponseFormat | None]:
         """Get the model with appropriate tool bindings.

         Performs auto-detection of strategy if needed based on model capabilities.

         Args:
-            request: The model request containing model, tools, and response format.
+            model_call: The model call containing model, tools, and response format.

         Returns:
             Tuple of (bound_model, effective_response_format) where ``effective_response_format``
@@ -765,7 +767,7 @@ def create_agent(  # noqa: PLR0915
         # Check if any requested tools are unknown CLIENT-SIDE tools
         unknown_tool_names = []
-        for t in request.tools:
+        for t in model_call.tools:
             # Only validate BaseTool instances (skip built-in dict tools)
             if isinstance(t, dict):
                 continue
@@ -782,7 +784,7 @@ def create_agent(  # noqa: PLR0915
                 "the 'tools' parameter\n"
                 "2. If using custom middleware with tools, ensure "
                 "they're registered via middleware.tools attribute\n"
-                "3. Verify that tool names in ModelRequest.tools match "
+                "3. Verify that tool names in ModelCall.tools match "
                 "the actual tool.name values\n"
                 "Note: Built-in provider tools (dict format) can be added dynamically."
             )
@@ -790,22 +792,24 @@ def create_agent(  # noqa: PLR0915
         # Determine effective response format (auto-detect if needed)
         effective_response_format: ResponseFormat | None
-        if isinstance(request.response_format, AutoStrategy):
+        if isinstance(model_call.response_format, AutoStrategy):
             # User provided raw schema via AutoStrategy - auto-detect best strategy based on model
-            if _supports_provider_strategy(request.model):
+            if _supports_provider_strategy(model_call.model):
                 # Model supports provider strategy - use it
-                effective_response_format = ProviderStrategy(schema=request.response_format.schema)
+                effective_response_format = ProviderStrategy(
+                    schema=model_call.response_format.schema
+                )
             else:
                 # Model doesn't support provider strategy - use ToolStrategy
-                effective_response_format = ToolStrategy(schema=request.response_format.schema)
+                effective_response_format = ToolStrategy(schema=model_call.response_format.schema)
         else:
             # User explicitly specified a strategy - preserve it
-            effective_response_format = request.response_format
+            effective_response_format = model_call.response_format

         # Build final tools list including structured output tools
-        # request.tools now only contains BaseTool instances (converted from callables)
+        # model_call.tools now only contains BaseTool instances (converted from callables)
         # and dicts (built-ins)
-        final_tools = list(request.tools)
+        final_tools = list(model_call.tools)
         if isinstance(effective_response_format, ToolStrategy):
             # Add structured output tools to final tools list
             structured_tools = [info.tool for info in structured_output_tools.values()]
@@ -816,8 +820,8 @@ def create_agent(  # noqa: PLR0915
             # Use provider-specific structured output
             kwargs = effective_response_format.to_model_kwargs()
             return (
-                request.model.bind_tools(
-                    final_tools, strict=True, **kwargs, **request.model_settings
+                model_call.model.bind_tools(
+                    final_tools, strict=True, **kwargs, **model_call.model_settings
                 ),
                 effective_response_format,
             )
@@ -839,10 +843,10 @@ def create_agent(  # noqa: PLR0915
                 raise ValueError(msg)

             # Force tool use if we have structured output tools
-            tool_choice = "any" if structured_output_tools else request.tool_choice
+            tool_choice = "any" if structured_output_tools else model_call.tool_choice
             return (
-                request.model.bind_tools(
-                    final_tools, tool_choice=tool_choice, **request.model_settings
+                model_call.model.bind_tools(
+                    final_tools, tool_choice=tool_choice, **model_call.model_settings
                 ),
                 effective_response_format,
             )
@@ -850,145 +854,316 @@ def create_agent(  # noqa: PLR0915
         # No structured output - standard model binding
         if final_tools:
             return (
-                request.model.bind_tools(
-                    final_tools, tool_choice=request.tool_choice, **request.model_settings
+                model_call.model.bind_tools(
+                    final_tools, tool_choice=model_call.tool_choice, **model_call.model_settings
                 ),
                 None,
             )
-        return request.model.bind(**request.model_settings), None
+        return model_call.model.bind(**model_call.model_settings), None
-    def _execute_model_sync(request: ModelRequest) -> _InternalModelResponse:
-        """Execute model and return result or exception.
+    def _execute_model_sync(model_call: ModelCall) -> ModelResponse:
+        """Execute model and return ModelResponse with messages and structured output.

         This is the core model execution logic wrapped by wrap_model_call handlers.
-        """
-        try:
-            # Get the bound model (with auto-detection if needed)
-            model_, effective_response_format = _get_bound_model(request)
-            messages = request.messages
-            if request.system_prompt:
-                messages = [SystemMessage(request.system_prompt), *messages]
+        Handles model invocation, auto-detection of response format, and structured
+        output processing.

-            output = model_.invoke(messages)
-            return _InternalModelResponse(
-                result=output,
-                exception=None,
-                effective_response_format=effective_response_format,
-            )
-        except Exception as error:  # noqa: BLE001
-            # Catch all exceptions from model invocation
-            return _InternalModelResponse(
-                result=None,
-                exception=error,
-                effective_response_format=None,
-            )
+        Args:
+            model_call: The model call parameters.
+
+        Returns:
+            ModelResponse with result (list of messages) and structured_response (if applicable).
+
+        Raises:
+            Exception: Any exception from model invocation or structured output processing.
+        """
+        # Get the bound model (with auto-detection if needed)
+        model_, effective_response_format = _get_bound_model(model_call)
+        messages = model_call.messages
+        if model_call.system_prompt:
+            messages = [SystemMessage(model_call.system_prompt), *messages]
+
+        output: AIMessage = model_.invoke(messages)
+
+        # Handle structured output with provider strategy
+        if isinstance(effective_response_format, ProviderStrategy):
+            if not output.tool_calls:
+                provider_strategy_binding = ProviderStrategyBinding.from_schema_spec(
+                    effective_response_format.schema_spec
+                )
+                structured_response = provider_strategy_binding.parse(output)
+                return ModelResponse(result=[output], structured_response=structured_response)
+            return ModelResponse(result=[output])
+
+        # Handle structured output with tool strategy
+        if (
+            isinstance(effective_response_format, ToolStrategy)
+            and isinstance(output, AIMessage)
+            and output.tool_calls
+        ):
+            structured_tool_calls = [
+                tc for tc in output.tool_calls if tc["name"] in structured_output_tools
+            ]
+
+            if structured_tool_calls:
+                if len(structured_tool_calls) > 1:
+                    # Handle multiple structured outputs error
+                    tool_names = [tc["name"] for tc in structured_tool_calls]
+                    multiple_outputs_error = MultipleStructuredOutputsError(tool_names)
+                    should_retry, error_message = _handle_structured_output_error(
+                        multiple_outputs_error, effective_response_format
+                    )
+                    if not should_retry:
+                        raise multiple_outputs_error
+
+                    # Add error messages and retry
+                    tool_messages = [
+                        ToolMessage(
+                            content=error_message,
+                            tool_call_id=tc["id"],
+                            name=tc["name"],
+                        )
+                        for tc in structured_tool_calls
+                    ]
+                    return ModelResponse(result=[output, *tool_messages])
+
+                # Handle single structured output
+                tool_call = structured_tool_calls[0]
+                try:
+                    structured_tool_binding = structured_output_tools[tool_call["name"]]
+                    structured_response = structured_tool_binding.parse(tool_call["args"])
+
+                    tool_message_content = (
+                        effective_response_format.tool_message_content
+                        if effective_response_format.tool_message_content
+                        else f"Returning structured response: {structured_response}"
+                    )
+
+                    return ModelResponse(
+                        result=[
+                            output,
+                            ToolMessage(
+                                content=tool_message_content,
+                                tool_call_id=tool_call["id"],
+                                name=tool_call["name"],
+                            ),
+                        ],
+                        structured_response=structured_response,
+                    )
+                except Exception as exc:  # noqa: BLE001
+                    validation_error = StructuredOutputValidationError(tool_call["name"], exc)
+                    should_retry, error_message = _handle_structured_output_error(
+                        validation_error, effective_response_format
+                    )
+                    if not should_retry:
+                        raise validation_error
+
+                    return ModelResponse(
+                        result=[
+                            output,
+                            ToolMessage(
+                                content=error_message,
+                                tool_call_id=tool_call["id"],
+                                name=tool_call["name"],
+                            ),
+                        ]
+                    )
+
+        return ModelResponse(result=[output])
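Net effect of this rewrite: the base executor itself now returns the final `ModelResponse`, so middleware never see raw `AIMessage`s or the old `_InternalModelResponse` carrier. Roughly, for the single-structured-output path above, the value handed back looks like this (illustrative values; the `ModelResponse` line is left as a comment since its import path depends on this branch):

```python
from langchain_core.messages import AIMessage, ToolMessage

ai = AIMessage(
    content="",
    tool_calls=[{"name": "WeatherReport", "args": {"temp_c": 21}, "id": "call_1"}],
)
tool_msg = ToolMessage(
    content="Returning structured response: {'temp_c': 21}",
    tool_call_id="call_1",
    name="WeatherReport",
)
# response = ModelResponse(result=[ai, tool_msg], structured_response={"temp_c": 21})
# model_node then appends response.result to state["messages"] and, when set,
# stores response.structured_response under state["structured_response"].
```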
     def model_node(state: AgentState, runtime: Runtime[ContextT]) -> dict[str, Any]:
         """Sync model request handler with sequential middleware processing."""
-        request = ModelRequest(
+        # Create ModelCall with invocation parameters
+        model_call = ModelCall(
             model=model,
             tools=default_tools,
             system_prompt=system_prompt,
             response_format=initial_response_format,
             messages=state["messages"],
             tool_choice=None,
         )

+        # Create ModelRequest with model_call + state + runtime
+        request = ModelRequest(
+            model_call=model_call,
+            state=state,
+            runtime=runtime,
+        )
+
-        # Execute with or without handler
-        effective_response_format: Any = None
-
-        # Define base handler that executes the model
-        def base_handler(req: ModelRequest) -> AIMessage:
-            nonlocal effective_response_format
-            internal_response = _execute_model_sync(req)
-            if internal_response.exception is not None:
-                raise internal_response.exception
-            if internal_response.result is None:
-                msg = "Model execution succeeded but returned no result"
-                raise RuntimeError(msg)
-            effective_response_format = internal_response.effective_response_format
-            return internal_response.result
-
+        # Execute with or without middleware handlers
+        # Handler returns ModelResponse with messages and structured_response
         if wrap_model_call_handler is None:
             # No handlers - execute directly
-            output = base_handler(request)
+            response = _execute_model_sync(model_call)
         else:
             # Call composed handler with base handler
-            output = wrap_model_call_handler(request, base_handler)
-        return {
+            response = wrap_model_call_handler(request, _execute_model_sync)
+
+        # Build result dict with model call counts and messages
+        result: dict[str, Any] = {
             "thread_model_call_count": state.get("thread_model_call_count", 0) + 1,
             "run_model_call_count": state.get("run_model_call_count", 0) + 1,
-            **_handle_model_output(output, effective_response_format),
+            "messages": response.result,
         }
-    async def _execute_model_async(request: ModelRequest) -> _InternalModelResponse:
-        """Execute model asynchronously and return result or exception.
+        # Add structured response if present
+        if response.structured_response is not None:
+            result["structured_response"] = response.structured_response
+
+        return result
+
+    async def _execute_model_async(model_call: ModelCall) -> ModelResponse:
+        """Execute model asynchronously and return ModelResponse.
+
+        Returns ModelResponse with messages and structured output.

         This is the core async model execution logic wrapped by wrap_model_call handlers.
-        """
-        try:
-            # Get the bound model (with auto-detection if needed)
-            model_, effective_response_format = _get_bound_model(request)
-            messages = request.messages
-            if request.system_prompt:
-                messages = [SystemMessage(request.system_prompt), *messages]
+        Handles model invocation, auto-detection of response format, and structured
+        output processing.

-            output = await model_.ainvoke(messages)
-            return _InternalModelResponse(
-                result=output,
-                exception=None,
-                effective_response_format=effective_response_format,
-            )
-        except Exception as error:  # noqa: BLE001
-            # Catch all exceptions from model invocation
-            return _InternalModelResponse(
-                result=None,
-                exception=error,
-                effective_response_format=None,
-            )
+        Args:
+            model_call: The model call parameters.
+
+        Returns:
+            ModelResponse with result (list of messages) and structured_response (if applicable).
+
+        Raises:
+            Exception: Any exception from model invocation or structured output processing.
+        """
+        # Get the bound model (with auto-detection if needed)
+        model_, effective_response_format = _get_bound_model(model_call)
+        messages = model_call.messages
+        if model_call.system_prompt:
+            messages = [SystemMessage(model_call.system_prompt), *messages]
+
+        output: AIMessage = await model_.ainvoke(messages)
+        # Handle structured output with provider strategy
+        if isinstance(effective_response_format, ProviderStrategy):
+            if not output.tool_calls:
+                provider_strategy_binding = ProviderStrategyBinding.from_schema_spec(
+                    effective_response_format.schema_spec
+                )
+                structured_response = provider_strategy_binding.parse(output)
+                return ModelResponse(result=[output], structured_response=structured_response)
+            return ModelResponse(result=[output])
+
+        # Handle structured output with tool strategy
+        if (
+            isinstance(effective_response_format, ToolStrategy)
+            and isinstance(output, AIMessage)
+            and output.tool_calls
+        ):
+            structured_tool_calls = [
+                tc for tc in output.tool_calls if tc["name"] in structured_output_tools
+            ]
+
+            if structured_tool_calls:
+                if len(structured_tool_calls) > 1:
+                    # Handle multiple structured outputs error
+                    tool_names = [tc["name"] for tc in structured_tool_calls]
+                    multiple_outputs_error = MultipleStructuredOutputsError(tool_names)
+                    should_retry, error_message = _handle_structured_output_error(
+                        multiple_outputs_error, effective_response_format
+                    )
+                    if not should_retry:
+                        raise multiple_outputs_error
+
+                    # Add error messages and retry
+                    tool_messages = [
+                        ToolMessage(
+                            content=error_message,
+                            tool_call_id=tc["id"],
+                            name=tc["name"],
+                        )
+                        for tc in structured_tool_calls
+                    ]
+                    return ModelResponse(result=[output, *tool_messages])
+
+                # Handle single structured output
+                tool_call = structured_tool_calls[0]
+                try:
+                    structured_tool_binding = structured_output_tools[tool_call["name"]]
+                    structured_response = structured_tool_binding.parse(tool_call["args"])
+
+                    tool_message_content = (
+                        effective_response_format.tool_message_content
+                        if effective_response_format.tool_message_content
+                        else f"Returning structured response: {structured_response}"
+                    )
+
+                    return ModelResponse(
+                        result=[
+                            output,
+                            ToolMessage(
+                                content=tool_message_content,
+                                tool_call_id=tool_call["id"],
+                                name=tool_call["name"],
+                            ),
+                        ],
+                        structured_response=structured_response,
+                    )
+                except Exception as exc:  # noqa: BLE001
+                    validation_error = StructuredOutputValidationError(tool_call["name"], exc)
+                    should_retry, error_message = _handle_structured_output_error(
+                        validation_error, effective_response_format
+                    )
+                    if not should_retry:
+                        raise validation_error
+
+                    return ModelResponse(
+                        result=[
+                            output,
+                            ToolMessage(
+                                content=error_message,
+                                tool_call_id=tool_call["id"],
+                                name=tool_call["name"],
+                            ),
+                        ]
+                    )
+
+        return ModelResponse(result=[output])
     async def amodel_node(state: AgentState, runtime: Runtime[ContextT]) -> dict[str, Any]:
         """Async model request handler with sequential middleware processing."""
-        request = ModelRequest(
+        # Create ModelCall with invocation parameters
+        model_call = ModelCall(
             model=model,
             tools=default_tools,
             system_prompt=system_prompt,
             response_format=initial_response_format,
             messages=state["messages"],
             tool_choice=None,
         )

+        # Create ModelRequest with model_call + state + runtime
+        request = ModelRequest(
+            model_call=model_call,
+            state=state,
+            runtime=runtime,
+        )
+
-        # Execute with or without handler
-        effective_response_format: Any = None
-
-        # Define base async handler that executes the model
-        async def base_handler(req: ModelRequest) -> AIMessage:
-            nonlocal effective_response_format
-            internal_response = await _execute_model_async(req)
-            if internal_response.exception is not None:
-                raise internal_response.exception
-            if internal_response.result is None:
-                msg = "Model execution succeeded but returned no result"
-                raise RuntimeError(msg)
-            effective_response_format = internal_response.effective_response_format
-            return internal_response.result
-
+        # Execute with or without middleware handlers
+        # Handler returns ModelResponse with messages and structured_response
         if awrap_model_call_handler is None:
             # No async handlers - execute directly
-            output = await base_handler(request)
+            response = await _execute_model_async(model_call)
         else:
             # Call composed async handler with base handler
-            output = await awrap_model_call_handler(request, base_handler)
-        return {
+            response = await awrap_model_call_handler(request, _execute_model_async)
+
+        # Build result dict with model call counts and messages
+        result: dict[str, Any] = {
             "thread_model_call_count": state.get("thread_model_call_count", 0) + 1,
             "run_model_call_count": state.get("run_model_call_count", 0) + 1,
-            **_handle_model_output(output, effective_response_format),
+            "messages": response.result,
         }

+        # Add structured response if present
+        if response.structured_response is not None:
+            result["structured_response"] = response.structured_response
+
+        return result
+
     # Use sync or async based on model capabilities
     graph.add_node("model", RunnableCallable(model_node, amodel_node, trace=False))
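The cache/short-circuit example is dropped from the `wrap_model_call` docstring later in this change (see the types.py hunks below); a sketch of the same pattern ported to the new API, using the `wrap_model_call` decorator this branch exports. The dict cache and `_key` helper are hypothetical placeholders, not library APIs:

```python
from langchain.agents.middleware.types import ModelResponse, wrap_model_call

_cache: dict[str, ModelResponse] = {}


def _key(model_call) -> str:
    # Hypothetical cache key: text of the last message only.
    return str(model_call.messages[-1].content) if model_call.messages else ""


@wrap_model_call
def cached_model_call(request, handler):
    key = _key(request.model_call)
    if key in _cache:
        return _cache[key]  # short-circuit: skip the model entirely
    response = handler(request.model_call)
    _cache[key] = response
    return response
```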
@@ -16,7 +16,9 @@ from .tool_selection import LLMToolSelectorMiddleware
 from .types import (
     AgentMiddleware,
     AgentState,
+    ModelCall,
     ModelRequest,
+    ModelResponse,
     after_agent,
     after_model,
     before_agent,
@@ -35,9 +37,11 @@ __all__ = [
     "ContextEditingMiddleware",
     "HumanInTheLoopMiddleware",
     "LLMToolSelectorMiddleware",
+    "ModelCall",
     "ModelCallLimitMiddleware",
     "ModelFallbackMiddleware",
     "ModelRequest",
+    "ModelResponse",
     "PIIDetectionError",
     "PIIMiddleware",
     "PlanningMiddleware",
@@ -22,7 +22,12 @@ from langchain_core.messages import (
 from langchain_core.messages.utils import count_tokens_approximately
 from typing_extensions import Protocol

-from langchain.agents.middleware.types import AgentMiddleware, ModelRequest
+from langchain.agents.middleware.types import (
+    AgentMiddleware,
+    ModelCall,
+    ModelRequest,
+    ModelResponse,
+)

 DEFAULT_TOOL_PLACEHOLDER = "[cleared]"
@@ -209,11 +214,11 @@ class ContextEditingMiddleware(AgentMiddleware):
     def wrap_model_call(
         self,
         request: ModelRequest,
-        handler: Callable[[ModelRequest], AIMessage],
-    ) -> AIMessage:
+        handler: Callable[[ModelCall], ModelResponse],
+    ) -> ModelResponse:
         """Apply context edits before invoking the model via handler."""
-        if not request.messages:
-            return handler(request)
+        if not request.model_call.messages:
+            return handler(request.model_call)

         if self.token_count_method == "approximate":  # noqa: S105
@@ -221,18 +226,20 @@ class ContextEditingMiddleware(AgentMiddleware):
                 return count_tokens_approximately(messages)
         else:
             system_msg = (
-                [SystemMessage(content=request.system_prompt)] if request.system_prompt else []
+                [SystemMessage(content=request.model_call.system_prompt)]
+                if request.model_call.system_prompt
+                else []
             )

             def count_tokens(messages: Sequence[BaseMessage]) -> int:
-                return request.model.get_num_tokens_from_messages(
-                    system_msg + list(messages), request.tools
+                return request.model_call.model.get_num_tokens_from_messages(
+                    system_msg + list(messages), request.model_call.tools
                 )

         for edit in self.edits:
-            edit.apply(request.messages, count_tokens=count_tokens)
+            edit.apply(request.model_call.messages, count_tokens=count_tokens)

-        return handler(request)
+        return handler(request.model_call)


 __all__ = [
@@ -6,7 +6,9 @@ from typing import TYPE_CHECKING
 from langchain.agents.middleware.types import (
     AgentMiddleware,
+    ModelCall,
     ModelRequest,
+    ModelResponse,
 )
 from langchain.chat_models import init_chat_model
@@ -14,7 +16,6 @@ if TYPE_CHECKING:
     from collections.abc import Callable

     from langchain_core.language_models.chat_models import BaseChatModel
-    from langchain_core.messages import AIMessage


 class ModelFallbackMiddleware(AgentMiddleware):
@@ -68,18 +69,16 @@ class ModelFallbackMiddleware(AgentMiddleware):
     def wrap_model_call(
         self,
         request: ModelRequest,
-        handler: Callable[[ModelRequest], AIMessage],
-    ) -> AIMessage:
+        handler: Callable[[ModelCall], ModelResponse],
+    ) -> ModelResponse:
         """Try fallback models in sequence on errors.

         Args:
-            request: Initial model request.
-            state: Current agent state.
-            runtime: LangGraph runtime.
-            handler: Callback to execute the model.
+            request: Full model request including state and runtime.
+            handler: Callback to execute the model call.

         Returns:
-            AIMessage from successful model call.
+            ModelResponse from successful model call.

         Raises:
             Exception: If all models fail, re-raises last exception.
@@ -87,15 +86,15 @@ class ModelFallbackMiddleware(AgentMiddleware):
         # Try primary model first
         last_exception: Exception
         try:
-            return handler(request)
+            return handler(request.model_call)
         except Exception as e:  # noqa: BLE001
             last_exception = e

         # Try fallback models
         for fallback_model in self.models:
-            request.model = fallback_model
+            request.model_call.model = fallback_model
             try:
-                return handler(request)
+                return handler(request.model_call)
             except Exception as e:  # noqa: BLE001
                 last_exception = e
                 continue
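The fallback loop above shows the idiomatic retry shape under the new API: mutate `request.model_call.model` in place, then call the handler again. For comparison, the same behavior hand-rolled with the decorator form (the model id passed to `init_chat_model` is illustrative):

```python
from langchain.agents.middleware.types import wrap_model_call
from langchain.chat_models import init_chat_model

fallback_models = [init_chat_model("openai:gpt-4o-mini")]  # illustrative


@wrap_model_call
def with_fallbacks(request, handler):
    try:
        return handler(request.model_call)
    except Exception as exc:  # noqa: BLE001
        last_exception = exc
    for fallback in fallback_models:
        request.model_call.model = fallback  # same in-place swap as above
        try:
            return handler(request.model_call)
        except Exception as exc:  # noqa: BLE001
            last_exception = exc
    raise last_exception
```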
@@ -8,12 +8,18 @@ from typing import TYPE_CHECKING, Annotated, Literal
 if TYPE_CHECKING:
     from collections.abc import Callable

-    from langchain_core.messages import AIMessage, ToolMessage
+    from langchain_core.messages import ToolMessage
     from langchain_core.tools import tool
     from langgraph.types import Command
     from typing_extensions import NotRequired, TypedDict

-from langchain.agents.middleware.types import AgentMiddleware, AgentState, ModelRequest
+from langchain.agents.middleware.types import (
+    AgentMiddleware,
+    AgentState,
+    ModelCall,
+    ModelRequest,
+    ModelResponse,
+)
 from langchain.tools import InjectedToolCallId
@@ -189,12 +195,12 @@ class PlanningMiddleware(AgentMiddleware):
     def wrap_model_call(
         self,
         request: ModelRequest,
-        handler: Callable[[ModelRequest], AIMessage],
-    ) -> AIMessage:
+        handler: Callable[[ModelCall], ModelResponse],
+    ) -> ModelResponse:
         """Update the system prompt to include the todo system prompt."""
-        request.system_prompt = (
-            request.system_prompt + "\n\n" + self.system_prompt
-            if request.system_prompt
+        request.model_call.system_prompt = (
+            request.model_call.system_prompt + "\n\n" + self.system_prompt
+            if request.model_call.system_prompt
             else self.system_prompt
         )
-        return handler(request)
+        return handler(request.model_call)
@@ -4,9 +4,12 @@ from collections.abc import Callable
 from typing import Literal
 from warnings import warn

-from langchain_core.messages import AIMessage
-
-from langchain.agents.middleware.types import AgentMiddleware, ModelRequest
+from langchain.agents.middleware.types import (
+    AgentMiddleware,
+    ModelCall,
+    ModelRequest,
+    ModelResponse,
+)


 class AnthropicPromptCachingMiddleware(AgentMiddleware):
@@ -45,8 +48,8 @@ class AnthropicPromptCachingMiddleware(AgentMiddleware):
     def wrap_model_call(
         self,
         request: ModelRequest,
-        handler: Callable[[ModelRequest], AIMessage],
-    ) -> AIMessage:
+        handler: Callable[[ModelCall], ModelResponse],
+    ) -> ModelResponse:
         """Modify the model request to add cache control blocks."""
         try:
             from langchain_anthropic import ChatAnthropic
@@ -61,10 +64,10 @@ class AnthropicPromptCachingMiddleware(AgentMiddleware):
                     "Anthropic models. "
                     "Please install langchain-anthropic."
                 )
-        elif not isinstance(request.model, ChatAnthropic):
+        elif not isinstance(request.model_call.model, ChatAnthropic):
             msg = (
                 "AnthropicPromptCachingMiddleware caching middleware only supports "
-                f"Anthropic models, not instances of {type(request.model)}"
+                f"Anthropic models, not instances of {type(request.model_call.model)}"
             )

         if msg is not None:
@@ -73,14 +76,16 @@ class AnthropicPromptCachingMiddleware(AgentMiddleware):
             if self.unsupported_model_behavior == "warn":
                 warn(msg, stacklevel=3)
             else:
-                return handler(request)
+                return handler(request.model_call)

         messages_count = (
-            len(request.messages) + 1 if request.system_prompt else len(request.messages)
+            len(request.model_call.messages) + 1
+            if request.model_call.system_prompt
+            else len(request.model_call.messages)
         )
         if messages_count < self.min_messages_to_cache:
-            return handler(request)
+            return handler(request.model_call)

-        request.model_settings["cache_control"] = {"type": self.type, "ttl": self.ttl}
+        request.model_call.model_settings["cache_control"] = {"type": self.type, "ttl": self.ttl}

-        return handler(request)
+        return handler(request.model_call)
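The `cache_control` entry written into `request.model_call.model_settings` above is consumed later by `_get_bound_model`, which splats `**model_call.model_settings` into `bind_tools`/`bind` (see the create_agent hunks earlier). Typical wiring might look as follows; the import path and constructor arguments are assumptions inferred from the attributes this class reads (`self.ttl`, `self.type`, `self.min_messages_to_cache`):

```python
from langchain.agents import create_agent
from langchain.agents.middleware import AnthropicPromptCachingMiddleware  # path assumed

agent = create_agent(
    model="anthropic:claude-3-5-sonnet-latest",  # illustrative model id
    middleware=[AnthropicPromptCachingMiddleware(ttl="5m")],  # args assumed
)
```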
@@ -12,11 +12,16 @@ if TYPE_CHECKING:
     from langchain.tools import BaseTool

     from langchain_core.language_models.chat_models import BaseChatModel
-    from langchain_core.messages import AIMessage, HumanMessage
+    from langchain_core.messages import HumanMessage
     from pydantic import Field, TypeAdapter
     from typing_extensions import TypedDict

-from langchain.agents.middleware.types import AgentMiddleware, ModelRequest
+from langchain.agents.middleware.types import (
+    AgentMiddleware,
+    ModelCall,
+    ModelRequest,
+    ModelResponse,
+)
 from langchain.chat_models.base import init_chat_model

 logger = logging.getLogger(__name__)
@@ -142,11 +147,11 @@ class LLMToolSelectorMiddleware(AgentMiddleware):
             SelectionRequest with prepared inputs, or None if no selection is needed.
         """
         # If no tools available, return None
-        if not request.tools or len(request.tools) == 0:
+        if not request.model_call.tools or len(request.model_call.tools) == 0:
             return None

         # Filter to only BaseTool instances (exclude provider-specific tool dicts)
-        base_tools = [tool for tool in request.tools if not isinstance(tool, dict)]
+        base_tools = [tool for tool in request.model_call.tools if not isinstance(tool, dict)]

         # Validate that always_include tools exist
         if self.always_include:
@@ -180,7 +185,7 @@ class LLMToolSelectorMiddleware(AgentMiddleware):

         # Get the last user message from the conversation history
         last_user_message: HumanMessage
-        for message in reversed(request.messages):
+        for message in reversed(request.model_call.messages):
             if isinstance(message, HumanMessage):
                 last_user_message = message
                 break
@@ -188,7 +193,7 @@ class LLMToolSelectorMiddleware(AgentMiddleware):
             msg = "No user message found in request messages"
             raise AssertionError(msg)

-        model = self.model or request.model
+        model = self.model or request.model_call.model
         valid_tool_names = [tool.name for tool in available_tools]

         return _SelectionRequest(
@@ -205,8 +210,8 @@ class LLMToolSelectorMiddleware(AgentMiddleware):
         available_tools: list[BaseTool],
         valid_tool_names: list[str],
         request: ModelRequest,
-    ) -> ModelRequest:
-        """Process the selection response and return filtered ModelRequest."""
+    ) -> None:
+        """Process the selection response and update ModelRequest with filtered tools."""
         selected_tool_names: list[str] = []
         invalid_tool_selections = []
@@ -231,26 +236,25 @@ class LLMToolSelectorMiddleware(AgentMiddleware):
         ]
         always_included_tools: list[BaseTool] = [
             tool
-            for tool in request.tools
+            for tool in request.model_call.tools
             if not isinstance(tool, dict) and tool.name in self.always_include
         ]
         selected_tools.extend(always_included_tools)

         # Also preserve any provider-specific tool dicts from the original request
-        provider_tools = [tool for tool in request.tools if isinstance(tool, dict)]
+        provider_tools = [tool for tool in request.model_call.tools if isinstance(tool, dict)]

-        request.tools = [*selected_tools, *provider_tools]
-        return request
+        request.model_call.tools = [*selected_tools, *provider_tools]

     def wrap_model_call(
         self,
         request: ModelRequest,
-        handler: Callable[[ModelRequest], AIMessage],
-    ) -> AIMessage:
+        handler: Callable[[ModelCall], ModelResponse],
+    ) -> ModelResponse:
         """Filter tools based on LLM selection before invoking the model via handler."""
         selection_request = self._prepare_selection_request(request)
         if selection_request is None:
-            return handler(request)
+            return handler(request.model_call)

         # Create dynamic response model with Literal enum of available tool names
         type_adapter = _create_tool_selection_response(selection_request.available_tools)
@@ -268,20 +272,20 @@ class LLMToolSelectorMiddleware(AgentMiddleware):
         if not isinstance(response, dict):
             msg = f"Expected dict response, got {type(response)}"
             raise AssertionError(msg)
-        modified_request = self._process_selection_response(
+        self._process_selection_response(
             response, selection_request.available_tools, selection_request.valid_tool_names, request
         )
-        return handler(modified_request)
+        return handler(request.model_call)

     async def awrap_model_call(
         self,
         request: ModelRequest,
-        handler: Callable[[ModelRequest], Awaitable[AIMessage]],
-    ) -> AIMessage:
+        handler: Callable[[ModelCall], Awaitable[ModelResponse]],
+    ) -> ModelResponse:
         """Filter tools based on LLM selection before invoking the model via handler."""
         selection_request = self._prepare_selection_request(request)
         if selection_request is None:
-            return await handler(request)
+            return await handler(request.model_call)

         # Create dynamic response model with Literal enum of available tool names
         type_adapter = _create_tool_selection_response(selection_request.available_tools)
@@ -299,7 +303,7 @@ class LLMToolSelectorMiddleware(AgentMiddleware):
         if not isinstance(response, dict):
             msg = f"Expected dict response, got {type(response)}"
             raise AssertionError(msg)
-        modified_request = self._process_selection_response(
+        self._process_selection_response(
             response, selection_request.available_tools, selection_request.valid_tool_names, request
         )
-        return await handler(modified_request)
+        return await handler(request.model_call)
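Note the pattern change in this file: `_process_selection_response` no longer returns a modified request; it mutates `request.model_call.tools` in place and the caller then invokes `handler(request.model_call)`. A minimal custom filter in the same style (the name-prefix predicate is purely illustrative):

```python
from langchain.agents.middleware.types import wrap_model_call


@wrap_model_call
def only_weather_tools(request, handler):
    # Keep provider-specific dict tools untouched; filter BaseTool instances.
    request.model_call.tools = [
        t for t in request.model_call.tools
        if isinstance(t, dict) or t.name.startswith("get_weather")
    ]
    return handler(request.model_call)
```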
@@ -2,7 +2,7 @@

 from __future__ import annotations

-from collections.abc import Awaitable, Callable
+from collections.abc import Callable
 from dataclasses import dataclass, field
 from inspect import iscoroutinefunction
 from typing import (
@@ -22,7 +22,7 @@ if TYPE_CHECKING:
     from langchain.tools.tool_node import ToolCallRequest

 # needed as top level import for pydantic schema generation on AgentState
-from langchain_core.messages import AIMessage, AnyMessage, ToolMessage  # noqa: TC002
+from langchain_core.messages import AnyMessage, BaseMessage, ToolMessage  # noqa: TC002
 from langgraph.channels.ephemeral_value import EphemeralValue
 from langgraph.channels.untracked_value import UntrackedValue
 from langgraph.graph.message import add_messages
@@ -41,7 +41,9 @@ __all__ = [
     "AgentMiddleware",
     "AgentState",
     "ContextT",
+    "ModelCall",
     "ModelRequest",
+    "ModelResponse",
     "OmitFromSchema",
     "PublicAgentState",
     "after_agent",
@@ -60,8 +62,11 @@ ResponseT = TypeVar("ResponseT")


 @dataclass
-class ModelRequest:
-    """Model request information for the agent."""
+class ModelCall:
+    """Model invocation parameters for a single model call.
+
+    Contains only the parameters needed to invoke the model, without agent context.
+    """

     model: BaseChatModel
     system_prompt: str | None
@@ -69,9 +74,34 @@
     tool_choice: Any | None
     tools: list[BaseTool | dict]
     response_format: ResponseFormat | None
     model_settings: dict[str, Any] = field(default_factory=dict)


 @dataclass
+class ModelRequest:
+    """Full request context for model invocation including agent state.
+
+    Combines model invocation parameters with agent state and runtime context.
+    """
+
+    model_call: ModelCall
+    state: AgentState
+    runtime: Runtime[ContextT]  # type: ignore[valid-type]
+    model_settings: dict[str, Any] = field(default_factory=dict)
+
+
+@dataclass
+class ModelResponse:
+    """Response from model execution including messages and optional structured output.
+
+    The result will usually contain a single AIMessage, but may include
+    an additional ToolMessage if the model used a tool for structured output.
+    """
+
+    result: list[BaseMessage]
+    """List of messages from model execution."""
+
+    structured_response: Any = None
+    """Parsed structured output if response_format was specified, None otherwise."""
+
+
+@dataclass
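The new data model in one picture: `ModelCall` carries only invocation parameters, `ModelRequest` wraps one `ModelCall` together with agent `state` and the LangGraph `runtime`, and `ModelResponse` is what handlers now return. A construction sketch; `my_chat_model` and `my_runtime` are assumed to exist (any `BaseChatModel`, and a LangGraph `Runtime` normally supplied by the framework):

```python
from langchain_core.messages import AIMessage, HumanMessage

from langchain.agents.middleware.types import ModelCall, ModelRequest, ModelResponse

call = ModelCall(
    model=my_chat_model,        # assumed: any BaseChatModel instance
    system_prompt="Be terse.",
    messages=[HumanMessage("hi")],
    tool_choice=None,
    tools=[],
    response_format=None,
)
request = ModelRequest(model_call=call, state={"messages": []}, runtime=my_runtime)
response = ModelResponse(result=[AIMessage(content="hello")])
```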
@@ -167,23 +197,23 @@ class AgentMiddleware(Generic[StateT, ContextT]):
     def wrap_model_call(
         self,
         request: ModelRequest,
-        handler: Callable[[ModelRequest], AIMessage],
-    ) -> AIMessage:
+        handler: Callable[[ModelCall], ModelResponse],
+    ) -> ModelResponse:
         """Intercept and control model execution via handler callback.

-        The handler callback executes the model request and returns an AIMessage.
-        Middleware can call the handler multiple times for retry logic, skip calling
-        it to short-circuit, or modify the request/response. Multiple middleware
-        compose with first in list as outermost layer.
+        The handler callback executes the model call and returns a ModelResponse containing
+        messages and optional structured_response. Middleware can call the handler multiple
+        times for retry logic, skip calling it to short-circuit, or modify the request/response.
+        Multiple middleware compose with first in list as outermost layer.

         Args:
-            request: Model request to execute (includes state and runtime).
-            handler: Callback that executes the model request and returns AIMessage.
-                Call this to execute the model. Can be called multiple times
-                for retry logic. Can skip calling it to short-circuit.
+            request: Full model request including state and runtime context.
+            handler: Callback that executes the model call and returns ModelResponse.
+                Pass request.model_call to execute the model. Can be called
+                multiple times for retry logic. Can skip calling it to short-circuit.

         Returns:
-            Final AIMessage to use (from handler or custom).
+            Final ModelResponse to use (from handler or custom).

         Examples:
             Retry on error:
@@ -191,36 +221,40 @@ class AgentMiddleware(Generic[StateT, ContextT]):
             def wrap_model_call(self, request, handler):
                 for attempt in range(3):
                     try:
-                        return handler(request)
+                        return handler(request.model_call)
                     except Exception:
                         if attempt == 2:
                             raise
             ```

-            Rewrite response:
+            Modify messages:
             ```python
             def wrap_model_call(self, request, handler):
-                result = handler(request)
-                return AIMessage(content=f"[{result.content}]")
+                response = handler(request.model_call)
+                # Modify first message (AIMessage)
+                ai_msg = response.result[0]
+                modified = AIMessage(content=f"[{ai_msg.content}]")
+                return ModelResponse(
+                    result=[modified, *response.result[1:]],
+                    structured_response=response.structured_response,
+                )
             ```

             Error to fallback:
             ```python
             def wrap_model_call(self, request, handler):
                 try:
-                    return handler(request)
+                    return handler(request.model_call)
                 except Exception:
-                    return AIMessage(content="Service unavailable")
+                    return ModelResponse(result=[AIMessage(content="Service unavailable")])
             ```

-            Cache/short-circuit:
+            Modify model settings:
             ```python
             def wrap_model_call(self, request, handler):
-                if cached := get_cache(request):
-                    return cached  # Short-circuit with cached result
-                result = handler(request)
-                save_cache(request, result)
-                return result
+                # Modify the model call parameters
+                request.model_call.model_settings["temperature"] = 0.7
+                return handler(request.model_call)
             ```
         """
         raise NotImplementedError
@@ -228,16 +262,17 @@ class AgentMiddleware(Generic[StateT, ContextT]):
     async def awrap_model_call(
         self,
         request: ModelRequest,
-        handler: Callable[[ModelRequest], Awaitable[AIMessage]],
-    ) -> AIMessage:
+        handler: Callable[[ModelCall], Awaitable[ModelResponse]],
+    ) -> ModelResponse:
         """Async version of wrap_model_call.

         Args:
-            request: Model request to execute (includes state and runtime).
-            handler: Async callback that executes the model request.
+            request: Full model request including state and runtime context.
+            handler: Async callback that executes the model call and returns ModelResponse.
+                Pass request.model_call to execute the model.

         Returns:
-            Final AIMessage to use (from handler or custom).
+            Final ModelResponse to use (from handler or custom).

         Examples:
             Retry on error:
@@ -245,7 +280,7 @@ class AgentMiddleware(Generic[StateT, ContextT]):
             async def awrap_model_call(self, request, handler):
                 for attempt in range(3):
                     try:
-                        return await handler(request)
+                        return await handler(request.model_call)
                     except Exception:
                         if attempt == 2:
                             raise
@@ -337,14 +372,14 @@ class _CallableReturningPromptString(Protocol[StateT_contra, ContextT]):  # type: ignore[misc]
 class _CallableReturningModelResponse(Protocol[StateT_contra, ContextT]):  # type: ignore[misc]
     """Callable for model call interception with handler callback.

-    Receives handler callback to execute model and returns final AIMessage.
+    Receives handler callback to execute model and returns final ModelResponse.
     """

     def __call__(
         self,
         request: ModelRequest,
-        handler: Callable[[ModelRequest], AIMessage],
-    ) -> AIMessage:
+        handler: Callable[[ModelCall], ModelResponse],
+    ) -> ModelResponse:
         """Intercept model execution via handler callback."""
         ...
@@ -1037,11 +1072,11 @@ def dynamic_prompt(
         async def async_wrapped(
             self: AgentMiddleware[StateT, ContextT],  # noqa: ARG001
             request: ModelRequest,
-            handler: Callable[[ModelRequest], Awaitable[AIMessage]],
-        ) -> AIMessage:
+            handler: Callable[[ModelCall], Awaitable[ModelResponse]],
+        ) -> ModelResponse:
             prompt = await func(request)  # type: ignore[misc]
-            request.system_prompt = prompt
-            return await handler(request)
+            request.model_call.system_prompt = prompt
+            return await handler(request.model_call)

         middleware_name = cast("str", getattr(func, "__name__", "DynamicPromptMiddleware"))
@@ -1058,11 +1093,11 @@ def dynamic_prompt(
         def wrapped(
             self: AgentMiddleware[StateT, ContextT],  # noqa: ARG001
             request: ModelRequest,
-            handler: Callable[[ModelRequest], AIMessage],
-        ) -> AIMessage:
+            handler: Callable[[ModelCall], ModelResponse],
+        ) -> ModelResponse:
             prompt = cast("str", func(request))
-            request.system_prompt = prompt
-            return handler(request)
+            request.model_call.system_prompt = prompt
+            return handler(request.model_call)

         middleware_name = cast("str", getattr(func, "__name__", "DynamicPromptMiddleware"))
@@ -1176,8 +1211,8 @@ def wrap_model_call(
         async def async_wrapped(
             self: AgentMiddleware[StateT, ContextT],  # noqa: ARG001
             request: ModelRequest,
-            handler: Callable[[ModelRequest], Awaitable[AIMessage]],
-        ) -> AIMessage:
+            handler: Callable[[ModelCall], Awaitable[ModelResponse]],
+        ) -> ModelResponse:
             return await func(request, handler)  # type: ignore[misc, arg-type]

         middleware_name = name or cast(
@@ -1197,8 +1232,8 @@ def wrap_model_call(
         def wrapped(
             self: AgentMiddleware[StateT, ContextT],  # noqa: ARG001
             request: ModelRequest,
-            handler: Callable[[ModelRequest], AIMessage],
-        ) -> AIMessage:
+            handler: Callable[[ModelCall], ModelResponse],
+        ) -> ModelResponse:
             return func(request, handler)

         middleware_name = name or cast("str", getattr(func, "__name__", "WrapModelCallMiddleware"))
@@ -115,7 +115,7 @@ class TestLLMToolSelectorBasic:
             """Middleware to select relevant tools based on state/context."""
             # Select a small, relevant subset of tools based on state/context
             model_requests.append(request)
-            return handler(request)
+            return handler(request.model_call)

         tool_selection_model = FakeModel(
             messages=cycle(
@@ -161,7 +161,9 @@ class TestLLMToolSelectorBasic:
         assert isinstance(response["messages"][-1], AIMessage)

         for request in model_requests:
-            selected_tool_names = [tool.name for tool in request.tools] if request.tools else []
+            selected_tool_names = (
+                [tool.name for tool in request.model_call.tools] if request.model_call.tools else []
+            )
             assert selected_tool_names == ["get_weather", "calculate"]

     async def test_async_basic_selection(self) -> None:
@@ -218,7 +220,7 @@ class TestMaxToolsLimiting:
         @wrap_model_call
         def trace_model_requests(request, handler):
             model_requests.append(request)
-            return handler(request)
+            return handler(request.model_call)

         # Selector model tries to select 4 tools
         tool_selection_model = FakeModel(
@@ -261,8 +263,8 @@ class TestMaxToolsLimiting:
         # Verify only 2 tools were passed to the main model
         assert len(model_requests) > 0
         for request in model_requests:
-            assert len(request.tools) == 2
-            tool_names = [tool.name for tool in request.tools]
+            assert len(request.model_call.tools) == 2
+            tool_names = [tool.name for tool in request.model_call.tools]
             # Should be first 2 from the selection
             assert tool_names == ["get_weather", "search_web"]
@@ -273,7 +275,7 @@ class TestMaxToolsLimiting:
         @wrap_model_call
        def trace_model_requests(request, handler):
             model_requests.append(request)
-            return handler(request)
+            return handler(request.model_call)

         tool_selection_model = FakeModel(
             messages=cycle(
@@ -315,8 +317,8 @@ class TestMaxToolsLimiting:
         # All 4 selected tools should be present
         assert len(model_requests) > 0
         for request in model_requests:
-            assert len(request.tools) == 4
-            tool_names = [tool.name for tool in request.tools]
+            assert len(request.model_call.tools) == 4
+            tool_names = [tool.name for tool in request.model_call.tools]
             assert set(tool_names) == {
                 "get_weather",
                 "search_web",
@@ -335,7 +337,7 @@ class TestAlwaysInclude:
         @wrap_model_call
         def trace_model_requests(request, handler):
             model_requests.append(request)
-            return handler(request)
+            return handler(request.model_call)

         # Selector picks only search_web
         tool_selection_model = FakeModel(
@@ -373,7 +375,7 @@ class TestAlwaysInclude:
         # Both selected and always_include tools should be present
         assert len(model_requests) > 0
         for request in model_requests:
-            tool_names = [tool.name for tool in request.tools]
+            tool_names = [tool.name for tool in request.model_call.tools]
             assert "search_web" in tool_names
             assert "send_email" in tool_names
             assert len(tool_names) == 2
@@ -385,7 +387,7 @@ class TestAlwaysInclude:
         @wrap_model_call
         def trace_model_requests(request, handler):
             model_requests.append(request)
-            return handler(request)
+            return handler(request.model_call)

         # Selector picks 2 tools
         tool_selection_model = FakeModel(
@@ -425,8 +427,8 @@ class TestAlwaysInclude:
         # Should have 2 selected + 2 always_include = 4 total
         assert len(model_requests) > 0
         for request in model_requests:
-            assert len(request.tools) == 4
-            tool_names = [tool.name for tool in request.tools]
+            assert len(request.model_call.tools) == 4
+            tool_names = [tool.name for tool in request.model_call.tools]
             assert "get_weather" in tool_names
             assert "search_web" in tool_names
             assert "send_email" in tool_names
@@ -439,7 +441,7 @@ class TestAlwaysInclude:
         @wrap_model_call
         def trace_model_requests(request, handler):
             model_requests.append(request)
-            return handler(request)
+            return handler(request.model_call)

         # Selector picks 1 tool
         tool_selection_model = FakeModel(
@@ -478,8 +480,8 @@ class TestAlwaysInclude:
         # Should have 1 selected + 3 always_include = 4 total
         assert len(model_requests) > 0
         for request in model_requests:
-            assert len(request.tools) == 4
-            tool_names = [tool.name for tool in request.tools]
+            assert len(request.model_call.tools) == 4
+            tool_names = [tool.name for tool in request.model_call.tools]
             assert "get_weather" in tool_names
             assert "send_email" in tool_names
             assert "calculate" in tool_names
@@ -496,7 +498,7 @@ class TestDuplicateAndInvalidTools:
         @wrap_model_call
         def trace_model_requests(request, handler):
             model_requests.append(request)
-            return handler(request)
+            return handler(request.model_call)

         # Selector returns duplicates
         tool_selection_model = FakeModel(
@@ -538,7 +540,7 @@ class TestDuplicateAndInvalidTools:
         # Duplicates should be removed
         assert len(model_requests) > 0
         for request in model_requests:
-            tool_names = [tool.name for tool in request.tools]
+            tool_names = [tool.name for tool in request.model_call.tools]
             assert tool_names == ["get_weather", "search_web"]
             assert len(tool_names) == 2
@@ -549,7 +551,7 @@ class TestDuplicateAndInvalidTools:
         @wrap_model_call
         def trace_model_requests(request, handler):
             model_requests.append(request)
-            return handler(request)
+            return handler(request.model_call)

         # Selector returns duplicates but max_tools=2
         tool_selection_model = FakeModel(
@@ -592,7 +594,7 @@ class TestDuplicateAndInvalidTools:
         # Should deduplicate and respect max_tools
         assert len(model_requests) > 0
         for request in model_requests:
-            tool_names = [tool.name for tool in request.tools]
+            tool_names = [tool.name for tool in request.model_call.tools]
             assert len(tool_names) == 2
             assert "get_weather" in tool_names
             assert "search_web" in tool_names
@@ -8,7 +8,9 @@ from langchain.agents import create_agent
 from langchain.agents.middleware.types import (
     AgentMiddleware,
     AgentState,
+    ModelCall,
     ModelRequest,
+    ModelResponse,
     wrap_model_call,
 )
@@ -21,7 +23,7 @@ class TestOnModelCallDecorator:

         @wrap_model_call
         def passthrough_middleware(request, handler):
-            return handler(request)
+            return handler(request.model_call)

         # Should return an AgentMiddleware instance
         assert isinstance(passthrough_middleware, AgentMiddleware)
@@ -39,7 +41,7 @@ class TestOnModelCallDecorator:

         @wrap_model_call(name="CustomMiddleware")
         def my_middleware(request, handler):
-            return handler(request)
+            return handler(request.model_call)

         assert isinstance(my_middleware, AgentMiddleware)
         assert my_middleware.__class__.__name__ == "CustomMiddleware"
@@ -58,10 +60,10 @@ class TestOnModelCallDecorator:
|
||||
@wrap_model_call
|
||||
def retry_once(request, handler):
|
||||
try:
|
||||
return handler(request)
|
||||
return handler(request.model_call)
|
||||
except Exception:
|
||||
# Retry once
|
||||
return handler(request)
|
||||
return handler(request.model_call)
|
||||
|
||||
model = FailOnceThenSucceed(messages=iter([AIMessage(content="Success")]))
|
||||
agent = create_agent(model=model, middleware=[retry_once])
|
||||
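A minimal sketch of the retry pattern under the new signature, assuming the `ModelCall`/`ModelResponse` types this diff introduces; the retry budget of three is illustrative, not a library default:

```python
from langchain.agents.middleware.types import wrap_model_call

@wrap_model_call
def retry_up_to_three(request, handler):
    last_exc = None
    for _ in range(3):  # illustrative retry budget
        try:
            # the handler now takes the nested ModelCall, not the ModelRequest
            return handler(request.model_call)
        except Exception as exc:
            last_exc = exc
    raise last_exc
```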
@@ -76,8 +78,8 @@ class TestOnModelCallDecorator:

@wrap_model_call
def uppercase_responses(request, handler):
result = handler(request)
return AIMessage(content=result.content.upper())
result = handler(request.model_call)
return ModelResponse(result=[AIMessage(content=result.result[0].content.upper())])

model = GenericFakeChatModel(messages=iter([AIMessage(content="hello world")]))
agent = create_agent(model=model, middleware=[uppercase_responses])
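A sketch of response rewriting under the new types: the handler no longer returns a bare `AIMessage`, so the middleware unwraps the first message from `ModelResponse.result`, transforms it, and re-wraps it (assuming the API shown in this diff):

```python
from langchain_core.messages import AIMessage
from langchain.agents.middleware.types import ModelResponse, wrap_model_call

@wrap_model_call
def shout(request, handler):
    response = handler(request.model_call)
    ai_msg = response.result[0]  # ModelResponse.result is a list of messages
    return ModelResponse(result=[AIMessage(content=ai_msg.content.upper())])
```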
@@ -96,9 +98,9 @@ class TestOnModelCallDecorator:
@wrap_model_call
def error_to_fallback(request, handler):
try:
return handler(request)
return handler(request.model_call)
except Exception:
return AIMessage(content="Fallback response")
return ModelResponse(result=[AIMessage(content="Fallback response")])

model = AlwaysFailModel(messages=iter([]))
agent = create_agent(model=model, middleware=[error_to_fallback])
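The error-to-fallback shape, sketched under the same assumptions: an exception raised by the model call is converted into a successful `ModelResponse`, so the agent surfaces the fallback text instead of raising:

```python
from langchain_core.messages import AIMessage, HumanMessage
from langchain.agents.middleware.types import ModelResponse, wrap_model_call

@wrap_model_call
def fallback_on_error(request, handler):
    try:
        return handler(request.model_call)
    except Exception:
        return ModelResponse(result=[AIMessage(content="Fallback response")])

# usage sketch: the invoke below completes instead of propagating the failure
# agent = create_agent(model=some_failing_model, middleware=[fallback_on_error])
# agent.invoke({"messages": [HumanMessage("hi")]})
```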
@@ -114,7 +116,7 @@ class TestOnModelCallDecorator:
@wrap_model_call
def log_state(request, handler):
state_values.append(request.state.get("messages"))
return handler(request)
return handler(request.model_call)

model = GenericFakeChatModel(messages=iter([AIMessage(content="Response")]))
agent = create_agent(model=model, middleware=[log_state])
@@ -133,14 +135,14 @@ class TestOnModelCallDecorator:
@wrap_model_call
def outer_middleware(request, handler):
execution_order.append("outer-before")
result = handler(request)
result = handler(request.model_call)
execution_order.append("outer-after")
return result

@wrap_model_call
def inner_middleware(request, handler):
execution_order.append("inner-before")
result = handler(request)
result = handler(request.model_call)
execution_order.append("inner-after")
return result

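A sketch of how the two middleware above compose, assuming `create_agent` applies middleware in list order with the first entry outermost:

```python
# With middleware=[outer_middleware, inner_middleware], one model call
# unwinds as outer-before -> inner-before -> (model) -> inner-after -> outer-after.
agent = create_agent(model=model, middleware=[outer_middleware, inner_middleware])
agent.invoke({"messages": [HumanMessage("hi")]})
assert execution_order == [
    "outer-before", "inner-before", "inner-after", "outer-after",
]
```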
@@ -166,7 +168,7 @@ class TestOnModelCallDecorator:

@wrap_model_call(state_schema=CustomState)
def middleware_with_schema(request, handler):
return handler(request)
return handler(request.model_call)

assert isinstance(middleware_with_schema, AgentMiddleware)
# Custom state schema should be set
@@ -183,7 +185,7 @@ class TestOnModelCallDecorator:

@wrap_model_call(tools=[test_tool])
def middleware_with_tools(request, handler):
return handler(request)
return handler(request.model_call)

assert isinstance(middleware_with_tools, AgentMiddleware)
assert len(middleware_with_tools.tools) == 1
@@ -195,12 +197,12 @@ class TestOnModelCallDecorator:
# Without parentheses
@wrap_model_call
def middleware_no_parens(request, handler):
return handler(request)
return handler(request.model_call)

# With parentheses
@wrap_model_call()
def middleware_with_parens(request, handler):
return handler(request)
return handler(request.model_call)

assert isinstance(middleware_no_parens, AgentMiddleware)
assert isinstance(middleware_with_parens, AgentMiddleware)
@@ -210,7 +212,7 @@ class TestOnModelCallDecorator:

@wrap_model_call
def my_custom_middleware(request, handler):
return handler(request)
return handler(request.model_call)

assert my_custom_middleware.__class__.__name__ == "my_custom_middleware"

@@ -221,14 +223,14 @@ class TestOnModelCallDecorator:
@wrap_model_call
def decorated_middleware(request, handler):
execution_order.append("decorated-before")
result = handler(request)
result = handler(request.model_call)
execution_order.append("decorated-after")
return result

class ClassMiddleware(AgentMiddleware):
def wrap_model_call(self, request, handler):
execution_order.append("class-before")
result = handler(request)
result = handler(request.model_call)
execution_order.append("class-after")
return result

@@ -267,7 +269,7 @@ class TestOnModelCallDecorator:
for attempt in range(max_retries):
attempts.append(attempt + 1)
try:
return handler(request)
return handler(request.model_call)
except Exception as e:
last_exception = e
# On error, continue to next attempt
@@ -291,7 +293,7 @@ class TestOnModelCallDecorator:
@wrap_model_call
async def logging_middleware(request, handler):
call_log.append("before")
result = await handler(request)
result = await handler(request.model_call)
call_log.append("after")
return result

@@ -310,18 +312,9 @@ class TestOnModelCallDecorator:
@wrap_model_call
def add_system_prompt(request, handler):
# Modify request to add system prompt
modified_request = ModelRequest(
messages=request.messages,
model=request.model,
system_prompt="You are a helpful assistant",
tool_choice=request.tool_choice,
tools=request.tools,
response_format=request.response_format,
state={},
runtime=None,
)
modified_prompts.append(modified_request.system_prompt)
return handler(modified_request)
request.model_call.system_prompt = "You are a helpful assistant"
modified_prompts.append(request.model_call.system_prompt)
return handler(request.model_call)

model = GenericFakeChatModel(messages=iter([AIMessage(content="Response")]))
agent = create_agent(model=model, middleware=[add_system_prompt])
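The hunk above captures the shape of the whole migration: rebuilding a `ModelRequest` field by field is replaced with in-place mutation of the nested `ModelCall`. A minimal sketch of the new style, under the same assumptions:

```python
@wrap_model_call
def add_prompt(request, handler):
    # mutate the nested ModelCall and delegate; no reconstruction needed
    request.model_call.system_prompt = "You are a helpful assistant"
    return handler(request.model_call)
```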
@@ -335,13 +328,13 @@ class TestOnModelCallDecorator:

@wrap_model_call
def multi_transform(request, handler):
result = handler(request)
result = handler(request.model_call)

# First transformation: uppercase
content = result.content.upper()
content = result.result[0].content.upper()
# Second transformation: add prefix and suffix
content = f"[START] {content} [END]"
return AIMessage(content=content)
return ModelResponse(result=[AIMessage(content=content)])

model = GenericFakeChatModel(messages=iter([AIMessage(content="hello")]))
agent = create_agent(model=model, middleware=[multi_transform])

@@ -9,6 +9,7 @@ from langchain.agents.middleware.types import (
AgentMiddleware,
AgentState,
ModelRequest,
ModelResponse,
)


@@ -20,7 +21,7 @@ class TestBasicOnModelCall:

class PassthroughMiddleware(AgentMiddleware):
def wrap_model_call(self, request, handler):
return handler(request)
return handler(request.model_call)

model = GenericFakeChatModel(messages=iter([AIMessage(content="Hello")]))
agent = create_agent(model=model, middleware=[PassthroughMiddleware()])
@@ -37,7 +38,7 @@ class TestBasicOnModelCall:
class LoggingMiddleware(AgentMiddleware):
def wrap_model_call(self, request, handler):
call_log.append("before")
result = handler(request)
result = handler(request.model_call)
call_log.append("after")
return result

@@ -59,7 +60,7 @@ class TestBasicOnModelCall:

def wrap_model_call(self, request, handler):
self.call_count += 1
return handler(request)
return handler(request.model_call)

counter = CountingMiddleware()
model = GenericFakeChatModel(messages=iter([AIMessage(content="Reply")]))
@@ -91,11 +92,11 @@ class TestRetryMiddleware:

def wrap_model_call(self, request, handler):
try:
result = handler(request)
result = handler(request.model_call)
return result
except Exception:
self.retry_count += 1
result = handler(request)
result = handler(request.model_call)
return result

retry_middleware = RetryOnceMiddleware()
@@ -125,7 +126,7 @@ class TestRetryMiddleware:
for attempt in range(self.max_retries):
self.attempts.append(attempt + 1)
try:
result = handler(request)
result = handler(request.model_call)
return result
except Exception as e:
last_exception = e
@@ -152,8 +153,9 @@ class TestResponseRewriting:

class UppercaseMiddleware(AgentMiddleware):
def wrap_model_call(self, request, handler):
result = handler(request)
return AIMessage(content=result.content.upper())
result = handler(request.model_call)
ai_msg = result.result[0]
return ModelResponse(result=[AIMessage(content=ai_msg.content.upper())])

model = GenericFakeChatModel(messages=iter([AIMessage(content="hello world")]))
agent = create_agent(model=model, middleware=[UppercaseMiddleware()])
@@ -171,8 +173,9 @@ class TestResponseRewriting:
self.prefix = prefix

def wrap_model_call(self, request, handler):
result = handler(request)
return AIMessage(content=f"{self.prefix}{result.content}")
result = handler(request.model_call)
ai_msg = result.result[0]
return ModelResponse(result=[AIMessage(content=f"{self.prefix}{ai_msg.content}")])

model = GenericFakeChatModel(messages=iter([AIMessage(content="Response")]))
agent = create_agent(model=model, middleware=[PrefixMiddleware(prefix="[BOT]: ")])
@@ -195,10 +198,9 @@ class TestErrorHandling:
class ErrorToSuccessMiddleware(AgentMiddleware):
def wrap_model_call(self, request, handler):
try:
return handler(request)
return handler(request.model_call)
except Exception:
fallback = AIMessage(content="Error handled gracefully")
return fallback
return ModelResponse(result=[AIMessage(content="Error handled gracefully")])

model = AlwaysFailModel(messages=iter([]))
agent = create_agent(model=model, middleware=[ErrorToSuccessMiddleware()])
@@ -218,10 +220,11 @@ class TestErrorHandling:
class SelectiveErrorMiddleware(AgentMiddleware):
def wrap_model_call(self, request, handler):
try:
return handler(request)
return handler(request.model_call)
except ConnectionError:
fallback = AIMessage(content="Network issue, try again later")
return fallback
return ModelResponse(
result=[AIMessage(content="Network issue, try again later")]
)

model = SpecificErrorModel(messages=iter([]))
agent = create_agent(model=model, middleware=[SelectiveErrorMiddleware()])
@@ -238,13 +241,12 @@ class TestErrorHandling:
def wrap_model_call(self, request, handler):
try:
call_log.append("before-yield")
result = handler(request)
result = handler(request.model_call)
call_log.append("after-yield-success")
return result
except Exception:
call_log.append("caught-error")
fallback = AIMessage(content="Recovered from error")
return fallback
return ModelResponse(result=[AIMessage(content="Recovered from error")])

# Test 1: Success path
call_log.clear()
@@ -281,14 +283,18 @@ class TestShortCircuit:
class CachingMiddleware(AgentMiddleware):
def wrap_model_call(self, request, handler):
# Simple cache key based on last message
cache_key = str(request.messages[-1].content) if request.messages else ""
cache_key = (
str(request.model_call.messages[-1].content)
if request.model_call.messages
else ""
)

if cache_key in cache:
# Short-circuit with cached result
return cache[cache_key]
else:
# Execute and cache
result = handler(request)
result = handler(request.model_call)
cache[cache_key] = result
return result

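The short-circuit pattern in the hunk above is worth isolating: returning a cached `ModelResponse` skips the model call entirely. A condensed sketch, with `cache` as an illustrative plain dict:

```python
cache = {}

class SketchCachingMiddleware(AgentMiddleware):
    def wrap_model_call(self, request, handler):
        messages = request.model_call.messages
        key = str(messages[-1].content) if messages else ""
        if key not in cache:
            cache[key] = handler(request.model_call)  # execute once, then cache
        return cache[key]  # cache hit short-circuits the handler
```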
@@ -337,19 +343,9 @@ class TestRequestModification:

def wrap_model_call(self, request, handler):
# Modify request to add system prompt
modified_request = ModelRequest(
model=request.model,
system_prompt=self.system_prompt,
messages=request.messages,
tools=request.tools,
tool_choice=request.tool_choice,
response_format=request.response_format,
model_settings=request.model_settings,
state=request.state,
runtime=request.runtime,
)
received_requests.append(modified_request)
return handler(modified_request)
request.model_call.system_prompt = self.system_prompt
received_requests.append(request)
return handler(request.model_call)

model = GenericFakeChatModel(messages=iter([AIMessage(content="Response")]))
agent = create_agent(
@@ -360,7 +356,7 @@ class TestRequestModification:
result = agent.invoke({"messages": [HumanMessage("Test")]})

assert len(received_requests) == 1
assert received_requests[0].system_prompt == "You are a helpful assistant."
assert received_requests[0].model_call.system_prompt == "You are a helpful assistant."
assert result["messages"][1].content == "Response"


@@ -380,7 +376,7 @@ class TestStateAndRuntime:
"messages_count": len(request.state.get("messages", [])),
}
)
return handler(request)
return handler(request.model_call)

model = GenericFakeChatModel(messages=iter([AIMessage(content="Response")]))
agent = create_agent(model=model, middleware=[StateAwareMiddleware()])
@@ -399,7 +395,7 @@ class TestStateAndRuntime:
max_retries = 2
for attempt in range(max_retries):
try:
return handler(request)
return handler(request.model_call)
break  # Success
except Exception:
if attempt == max_retries - 1:
@@ -433,14 +429,14 @@ class TestMiddlewareComposition:
class OuterMiddleware(AgentMiddleware):
def wrap_model_call(self, request, handler):
execution_order.append("outer-before")
response = handler(request)
response = handler(request.model_call)
execution_order.append("outer-after")
return response

class InnerMiddleware(AgentMiddleware):
def wrap_model_call(self, request, handler):
execution_order.append("inner-before")
response = handler(request)
response = handler(request.model_call)
execution_order.append("inner-after")
return response

@@ -472,7 +468,7 @@ class TestMiddlewareComposition:
class LoggingMiddleware(AgentMiddleware):
def wrap_model_call(self, request, handler):
log.append("logging-before")
result = handler(request)
result = handler(request.model_call)
log.append("logging-after")
return result

@@ -480,12 +476,12 @@ class TestMiddlewareComposition:
def wrap_model_call(self, request, handler):
log.append("retry-before")
try:
result = handler(request)
result = handler(request.model_call)
log.append("retry-after")
return result
except Exception:
log.append("retry-retrying")
result = handler(request)
result = handler(request.model_call)
log.append("retry-after")
return result

@@ -510,13 +506,15 @@ class TestMiddlewareComposition:

class PrefixMiddleware(AgentMiddleware):
def wrap_model_call(self, request, handler):
result = handler(request)
return AIMessage(content=f"[PREFIX] {result.content}")
result = handler(request.model_call)
ai_msg = result.result[0]
return ModelResponse(result=[AIMessage(content=f"[PREFIX] {ai_msg.content}")])

class SuffixMiddleware(AgentMiddleware):
def wrap_model_call(self, request, handler):
result = handler(request)
return AIMessage(content=f"{result.content} [SUFFIX]")
result = handler(request.model_call)
ai_msg = result.result[0]
return ModelResponse(result=[AIMessage(content=f"{ai_msg.content} [SUFFIX]")])

model = GenericFakeChatModel(messages=iter([AIMessage(content="Middle")]))
# Prefix is outer, Suffix is inner
@@ -542,16 +540,17 @@ class TestMiddlewareComposition:
class RetryMiddleware(AgentMiddleware):
def wrap_model_call(self, request, handler):
try:
result = handler(request)
result = handler(request.model_call)
return result
except Exception:
result = handler(request)
result = handler(request.model_call)
return result

class UppercaseMiddleware(AgentMiddleware):
def wrap_model_call(self, request, handler):
result = handler(request)
return AIMessage(content=result.content.upper())
result = handler(request.model_call)
ai_msg = result.result[0]
return ModelResponse(result=[AIMessage(content=ai_msg.content.upper())])

model = FailOnceThenSucceed(messages=iter([AIMessage(content="success")]))
# Retry outer, Uppercase inner
@@ -569,21 +568,21 @@ class TestMiddlewareComposition:
class FirstMiddleware(AgentMiddleware):
def wrap_model_call(self, request, handler):
execution_order.append("first-before")
response = handler(request)
response = handler(request.model_call)
execution_order.append("first-after")
return response

class SecondMiddleware(AgentMiddleware):
def wrap_model_call(self, request, handler):
execution_order.append("second-before")
response = handler(request)
response = handler(request.model_call)
execution_order.append("second-after")
return response

class ThirdMiddleware(AgentMiddleware):
def wrap_model_call(self, request, handler):
execution_order.append("third-before")
response = handler(request)
response = handler(request.model_call)
execution_order.append("third-after")
return response

@@ -613,7 +612,7 @@ class TestMiddlewareComposition:
class OuterMiddleware(AgentMiddleware):
def wrap_model_call(self, request, handler):
execution_order.append("outer-before")
result = handler(request)
result = handler(request.model_call)
execution_order.append("outer-after")
return result

@@ -621,16 +620,16 @@ class TestMiddlewareComposition:
def wrap_model_call(self, request, handler):
execution_order.append("middle-before")
# Always retry once (call handler twice)
result = handler(request)
result = handler(request.model_call)
execution_order.append("middle-retry")
result = handler(request)
result = handler(request.model_call)
execution_order.append("middle-after")
return result

class InnerMiddleware(AgentMiddleware):
def wrap_model_call(self, request, handler):
execution_order.append("inner-before")
result = handler(request)
result = handler(request.model_call)
execution_order.append("inner-after")
return result

@@ -675,7 +674,7 @@ class TestAsyncOnModelCall:
class LoggingMiddleware(AgentMiddleware):
async def awrap_model_call(self, request, handler):
log.append("before")
result = await handler(request)
result = await handler(request.model_call)
log.append("after")

return result
@@ -702,9 +701,9 @@ class TestAsyncOnModelCall:
class RetryMiddleware(AgentMiddleware):
async def awrap_model_call(self, request, handler):
try:
return await handler(request)
return await handler(request.model_call)
except Exception:
return await handler(request)
return await handler(request.model_call)

model = AsyncFailOnceThenSucceed(messages=iter([AIMessage(content="Async success")]))
agent = create_agent(model=model, middleware=[RetryMiddleware()])
@@ -725,9 +724,8 @@ class TestEdgeCases:
class RequestModifyingMiddleware(AgentMiddleware):
def wrap_model_call(self, request, handler):
# Add a system message to the request
modified_request = request
modified_messages.append(len(modified_request.messages))
return handler(modified_request)
modified_messages.append(len(request.model_call.messages))
return handler(request.model_call)

model = GenericFakeChatModel(messages=iter([AIMessage(content="Response")]))
agent = create_agent(model=model, middleware=[RequestModifyingMiddleware()])
@@ -744,11 +742,11 @@ class TestEdgeCases:
def wrap_model_call(self, request, handler):
attempts.append("first-attempt")
try:
result = handler(request)
result = handler(request.model_call)
return result
except Exception:
attempts.append("retry-attempt")
result = handler(request)
result = handler(request.model_call)
return result

call_count = {"value": 0}

@@ -8,7 +8,7 @@ from langchain.agents.middleware.context_editing import (
ClearToolUsesEdit,
ContextEditingMiddleware,
)
from langchain.agents.middleware.types import AgentState, ModelRequest
from langchain.agents.middleware.types import AgentState, ModelCall, ModelRequest, ModelResponse
from langchain_core.language_models.fake_chat_models import FakeChatModel
from langchain_core.messages import (
AIMessage,
@@ -55,16 +55,19 @@ def _make_state_and_request(
model = _TokenCountingChatModel()
conversation = list(messages)
state = cast(AgentState, {"messages": conversation})
request = ModelRequest(
model_call = ModelCall(
model=model,
system_prompt=system_prompt,
messages=conversation,
tool_choice=None,
tools=[],
response_format=None,
model_settings={},
)
request = ModelRequest(
model_call=model_call,
state=state,
runtime=_fake_runtime(),
model_settings={},
)
return state, request
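The fixture above establishes the two-object construction used for the rest of these files: `ModelCall` carries everything sent to the model, and `ModelRequest` wraps it with graph state and runtime. A sketch with placeholder values (`my_model`, `conversation`, `state`, `runtime` are stand-ins, not fixture names):

```python
from langchain.agents.middleware.types import ModelCall, ModelRequest

model_call = ModelCall(
    model=my_model,
    system_prompt=None,
    messages=conversation,
    tool_choice=None,
    tools=[],
    response_format=None,
    model_settings={},
)
request = ModelRequest(model_call=model_call, state=state, runtime=runtime)
```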
@@ -82,16 +85,16 @@ def test_no_edit_when_below_trigger() -> None:
edits=[ClearToolUsesEdit(trigger=50)],
)

def mock_handler(req: ModelRequest) -> AIMessage:
return AIMessage(content="mock response")
def mock_handler(req: ModelCall) -> ModelResponse:
return ModelResponse(result=[AIMessage(content="mock response")])

# Call wrap_model_call which modifies the request
middleware.wrap_model_call(request, mock_handler)

# The request should have been modified in place
assert request.messages[0].content == ""
assert request.messages[1].content == "12345"
assert state["messages"] == request.messages
assert request.model_call.messages[0].content == ""
assert request.model_call.messages[1].content == "12345"
assert state["messages"] == request.model_call.messages


def test_clear_tool_outputs_and_inputs() -> None:
@@ -115,14 +118,14 @@ def test_clear_tool_outputs_and_inputs() -> None:
)
middleware = ContextEditingMiddleware(edits=[edit])

def mock_handler(req: ModelRequest) -> AIMessage:
return AIMessage(content="mock response")
def mock_handler(req: ModelCall) -> ModelResponse:
return ModelResponse(result=[AIMessage(content="mock response")])

# Call wrap_model_call which modifies the request
middleware.wrap_model_call(request, mock_handler)

cleared_ai = request.messages[0]
cleared_tool = request.messages[1]
cleared_ai = request.model_call.messages[0]
cleared_tool = request.model_call.messages[1]

assert isinstance(cleared_tool, ToolMessage)
assert cleared_tool.content == "[cleared output]"
@@ -134,7 +137,7 @@ def test_clear_tool_outputs_and_inputs() -> None:
assert context_meta is not None
assert context_meta["cleared_tool_inputs"] == [tool_call_id]

assert state["messages"] == request.messages
assert state["messages"] == request.model_call.messages


def test_respects_keep_last_tool_results() -> None:
@@ -167,21 +170,21 @@ def test_respects_keep_last_tool_results() -> None:
token_count_method="model",
)

def mock_handler(req: ModelRequest) -> AIMessage:
return AIMessage(content="mock response")
def mock_handler(req: ModelCall) -> ModelResponse:
return ModelResponse(result=[AIMessage(content="mock response")])

# Call wrap_model_call which modifies the request
middleware.wrap_model_call(request, mock_handler)

cleared_messages = [
msg
for msg in request.messages
for msg in request.model_call.messages
if isinstance(msg, ToolMessage) and msg.content == "[cleared]"
]

assert len(cleared_messages) == 2
assert isinstance(request.messages[-1], ToolMessage)
assert request.messages[-1].content != "[cleared]"
assert isinstance(request.model_call.messages[-1], ToolMessage)
assert request.model_call.messages[-1].content != "[cleared]"


def test_exclude_tools_prevents_clearing() -> None:
@@ -215,14 +218,14 @@ def test_exclude_tools_prevents_clearing() -> None:
],
)

def mock_handler(req: ModelRequest) -> AIMessage:
return AIMessage(content="mock response")
def mock_handler(req: ModelCall) -> ModelResponse:
return ModelResponse(result=[AIMessage(content="mock response")])

# Call wrap_model_call which modifies the request
middleware.wrap_model_call(request, mock_handler)

search_tool = request.messages[1]
calc_tool = request.messages[3]
search_tool = request.model_call.messages[1]
calc_tool = request.model_call.messages[3]

assert isinstance(search_tool, ToolMessage)
assert search_tool.content == "search-results" * 20

@@ -4,7 +4,7 @@ import pytest
from langchain_core.messages import AIMessage

from langchain.agents.factory import _chain_model_call_handlers
from langchain.agents.middleware.types import ModelRequest
from langchain.agents.middleware.types import ModelCall, ModelRequest, ModelResponse

from typing import cast
from langgraph.runtime import Runtime
@@ -13,18 +13,34 @@ from langgraph.runtime import Runtime
def create_test_request(**kwargs):
"""Helper to create a ModelRequest with sensible defaults."""

defaults = {
model_call_defaults = {
"messages": [],
"model": None,
"system_prompt": None,
"tool_choice": None,
"tools": [],
"response_format": None,
"model_settings": {},
}

# Extract model_call fields from kwargs
model_call_kwargs = {
k: kwargs.pop(k, v)
for k, v in model_call_defaults.items()
if k in kwargs or k in model_call_defaults
}

# Create ModelCall
model_call = ModelCall(**model_call_kwargs)

# Create ModelRequest with remaining kwargs
request_defaults = {
"model_call": model_call,
"state": {},
"runtime": cast(Runtime, object()),
}
defaults.update(kwargs)
return ModelRequest(**defaults)
request_defaults.update(kwargs)
return ModelRequest(**request_defaults)

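An illustrative use of the helper above: `ModelCall` fields may be passed by keyword and are popped out of `kwargs`; anything unspecified falls back to the defaults shown:

```python
from langchain_core.messages import HumanMessage

req = create_test_request(system_prompt="test", messages=[HumanMessage("hi")])
assert req.model_call.system_prompt == "test"
assert req.model_call.tools == []
```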
class TestChainModelCallHandlers:
@@ -65,7 +81,7 @@ class TestChainModelCallHandlers:

# Execute the composed handler
def mock_base_handler(req):
return AIMessage(content="test")
return ModelResponse(result=[AIMessage(content="test")])

result = composed(create_test_request(), mock_base_handler)

@@ -75,7 +91,7 @@ class TestChainModelCallHandlers:
"inner-after",
"outer-after",
]
assert result.content == "test"
assert result.result[0].content == "test"

def test_three_handlers_composition(self) -> None:
"""Test composition of three handlers."""
@@ -103,7 +119,7 @@ class TestChainModelCallHandlers:
assert composed is not None

def mock_base_handler(req):
return AIMessage(content="test")
return ModelResponse(result=[AIMessage(content="test")])

result = composed(create_test_request(), mock_base_handler)

@@ -116,7 +132,7 @@ class TestChainModelCallHandlers:
"second-after",
"first-after",
]
assert result.content == "test"
assert result.result[0].content == "test"

def test_inner_handler_retry(self) -> None:
"""Test inner handler retrying before outer sees response."""
@@ -144,12 +160,12 @@ class TestChainModelCallHandlers:
call_count["value"] += 1
if call_count["value"] < 3:
raise ValueError("fail")
return AIMessage(content="success")
return ModelResponse(result=[AIMessage(content="success")])

result = composed(create_test_request(), mock_base_handler)

assert inner_attempts == [0, 1, 2]
assert result.content == "success"
assert result.result[0].content == "success"

def test_error_to_success_conversion(self) -> None:
"""Test handler converting error to success response."""
@@ -158,7 +174,7 @@ class TestChainModelCallHandlers:
try:
return handler(request)
except Exception:
return AIMessage(content="Fallback response")
return ModelResponse(result=[AIMessage(content="Fallback response")])

def inner_passthrough(request, handler):
return handler(request)
@@ -171,32 +187,32 @@ class TestChainModelCallHandlers:

result = composed(create_test_request(), mock_base_handler)

assert result.content == "Fallback response"
assert result.result[0].content == "Fallback response"

def test_request_modification(self) -> None:
"""Test handlers modifying the request."""
requests_seen = []

def outer_add_context(request, handler):
modified_request = create_test_request(
messages=[*request.messages], system_prompt="Added by outer"
)
return handler(modified_request)
# Modify the model_call
request.model_call.system_prompt = "Added by outer"
return handler(request.model_call)

def inner_track_request(request, handler):
requests_seen.append(request.system_prompt)
return handler(request)
# Inner handler receives ModelRequest due to composition
requests_seen.append(request.model_call.system_prompt)
return handler(request.model_call)

composed = _chain_model_call_handlers([outer_add_context, inner_track_request])
assert composed is not None

def mock_base_handler(req):
return AIMessage(content="response")
return ModelResponse(result=[AIMessage(content="response")])

result = composed(create_test_request(), mock_base_handler)

assert requests_seen == ["Added by outer"]
assert result.content == "response"
assert result.result[0].content == "response"
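A sketch of the chaining semantics this test relies on: per the `compose_two` wrapper earlier in this diff, the inner handler is invoked with the original `ModelRequest` (the `ModelCall` passed by the outer handler is not threaded through), so in-place mutation of `request.model_call` is how changes propagate inward:

```python
composed = _chain_model_call_handlers([outer_add_context, inner_track_request])

def base(call):  # stand-in terminal handler over a ModelCall
    return ModelResponse(result=[AIMessage(content="response")])

assert composed(create_test_request(), base).result[0].content == "response"
```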
def test_composition_preserves_state_and_runtime(self) -> None:
"""Test that state and runtime are passed through composition."""
@@ -220,7 +236,7 @@ class TestChainModelCallHandlers:
test_runtime = {"test": "runtime"}

def mock_base_handler(req):
return AIMessage(content="test")
return ModelResponse(result=[AIMessage(content="test")])

# Create request with state and runtime
test_request = create_test_request()
@@ -231,7 +247,7 @@ class TestChainModelCallHandlers:
# Both handlers should see same state and runtime
assert state_values == [("outer", test_state), ("inner", test_state)]
assert runtime_values == [("outer", test_runtime), ("inner", test_runtime)]
assert result.content == "test"
assert result.result[0].content == "test"

def test_multiple_yields_in_retry_loop(self) -> None:
"""Test handler that retries multiple times."""
@@ -257,11 +273,11 @@ class TestChainModelCallHandlers:
attempt["value"] += 1
if attempt["value"] == 1:
raise ValueError("fail")
return AIMessage(content="ok")
return ModelResponse(result=[AIMessage(content="ok")])

result = composed(create_test_request(), mock_base_handler)

# Outer called once, inner retried so base handler called twice
assert call_count["value"] == 1
assert attempt["value"] == 2
assert result.content == "ok"
assert result.result[0].content == "ok"

@@ -50,7 +50,9 @@ from langchain.agents.middleware.types import (
AgentMiddleware,
AgentState,
hook_config,
ModelCall,
ModelRequest,
ModelResponse,
OmitFromInput,
OmitFromOutput,
PrivateStateAttr,
@@ -118,9 +120,9 @@ def test_create_agent_diagram(
def wrap_model_call(
self,
request: ModelRequest,
handler: Callable[[ModelRequest], AIMessage],
) -> AIMessage:
return handler(request)
handler: Callable[[ModelCall], ModelResponse],
) -> ModelResponse:
return handler(request.model_call)

def after_model(self, state, runtime):
pass
@@ -132,9 +134,9 @@ def test_create_agent_diagram(
def wrap_model_call(
self,
request: ModelRequest,
handler: Callable[[ModelRequest], AIMessage],
) -> AIMessage:
return handler(request)
handler: Callable[[ModelCall], ModelResponse],
) -> ModelResponse:
return handler(request.model_call)

def after_model(self, state, runtime):
pass
@@ -260,10 +262,10 @@ def test_create_agent_invoke(
def wrap_model_call(
self,
request: ModelRequest,
handler: Callable[[ModelRequest], AIMessage],
) -> AIMessage:
handler: Callable[[ModelCall], ModelResponse],
) -> ModelResponse:
calls.append("NoopSeven.wrap_model_call")
return handler(request)
return handler(request.model_call)

def after_model(self, state, runtime):
calls.append("NoopSeven.after_model")
@@ -275,10 +277,10 @@ def test_create_agent_invoke(
def wrap_model_call(
self,
request: ModelRequest,
handler: Callable[[ModelRequest], AIMessage],
) -> AIMessage:
handler: Callable[[ModelCall], ModelResponse],
) -> ModelResponse:
calls.append("NoopEight.wrap_model_call")
return handler(request)
return handler(request.model_call)

def after_model(self, state, runtime):
calls.append("NoopEight.after_model")
@@ -361,10 +363,10 @@ def test_create_agent_jump(
def wrap_model_call(
self,
request: ModelRequest,
handler: Callable[[ModelRequest], AIMessage],
) -> AIMessage:
handler: Callable[[ModelCall], ModelResponse],
) -> ModelResponse:
calls.append("NoopSeven.wrap_model_call")
return handler(request)
return handler(request.model_call)

def after_model(self, state, runtime):
calls.append("NoopSeven.after_model")
@@ -378,10 +380,10 @@ def test_create_agent_jump(
def wrap_model_call(
self,
request: ModelRequest,
handler: Callable[[ModelRequest], AIMessage],
) -> AIMessage:
handler: Callable[[ModelCall], ModelResponse],
) -> ModelResponse:
calls.append("NoopEight.wrap_model_call")
return handler(request)
return handler(request.model_call)

def after_model(self, state, runtime):
calls.append("NoopEight.after_model")
@@ -1032,46 +1034,54 @@ def test_anthropic_prompt_caching_middleware_initialization() -> None:
assert middleware.ttl == "5m"
assert middleware.min_messages_to_cache == 0

fake_request = ModelRequest(
model_call = ModelCall(
model=FakeToolCallingModel(),
messages=[HumanMessage("Hello")],
system_prompt=None,
tool_choice=None,
tools=[],
response_format=None,
state={"messages": [HumanMessage("Hello")]},
runtime=cast(Runtime, object()),
model_settings={},
)
fake_request = ModelRequest(
model_call=model_call,
state={"messages": [HumanMessage("Hello")]},
runtime=cast(Runtime, object()),
)

def mock_handler(req: ModelRequest) -> AIMessage:
return AIMessage(content="mock response", **req.model_settings)
def mock_handler(req: ModelCall) -> ModelResponse:
return ModelResponse(result=[AIMessage(content="mock response", **req.model_settings)])

result = middleware.wrap_model_call(fake_request, mock_handler)
# Check that model_settings were passed through via the request
assert fake_request.model_settings == {"cache_control": {"type": "ephemeral", "ttl": "5m"}}
assert fake_request.model_call.model_settings == {
"cache_control": {"type": "ephemeral", "ttl": "5m"}
}

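The test above shows the direct-drive pattern: middleware can be unit-tested by calling `wrap_model_call` with a stub handler, no agent graph required. A sketch under the same assumptions (echoing `model_settings` so the test can observe what the middleware set, as the mock above does):

```python
def stub_handler(call: ModelCall) -> ModelResponse:
    return ModelResponse(result=[AIMessage(content="stub", **call.model_settings)])

response = middleware.wrap_model_call(fake_request, stub_handler)
assert isinstance(response, ModelResponse)
```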
def test_anthropic_prompt_caching_middleware_unsupported_model() -> None:
"""Test AnthropicPromptCachingMiddleware with unsupported model."""
from typing import cast

fake_request = ModelRequest(
model_call = ModelCall(
model=FakeToolCallingModel(),
messages=[HumanMessage("Hello")],
system_prompt=None,
tool_choice=None,
tools=[],
response_format=None,
model_settings={},
)
fake_request = ModelRequest(
model_call=model_call,
state={"messages": [HumanMessage("Hello")]},
runtime=cast(Runtime, object()),
model_settings={},
)

middleware = AnthropicPromptCachingMiddleware(unsupported_model_behavior="raise")

def mock_handler(req: ModelRequest) -> AIMessage:
return AIMessage(content="mock response")
def mock_handler(req: ModelCall) -> ModelResponse:
return ModelResponse(result=[AIMessage(content="mock response")])

with pytest.raises(
ValueError,
@@ -1102,12 +1112,12 @@ def test_anthropic_prompt_caching_middleware_unsupported_model() -> None:
"AnthropicPromptCachingMiddleware caching middleware only supports Anthropic models. Please install langchain-anthropic."
in str(w[-1].message)
)
assert isinstance(result, AIMessage)
assert isinstance(result, ModelResponse)

with warnings.catch_warnings(record=True) as w:
with patch.dict("sys.modules", {"langchain_anthropic": langchain_anthropic}):
result = middleware.wrap_model_call(fake_request, mock_handler)
assert isinstance(result, AIMessage)
assert isinstance(result, ModelResponse)
assert len(w) == 1
assert (
"AnthropicPromptCachingMiddleware caching middleware only supports Anthropic models, not instances of"
@@ -1117,11 +1127,11 @@ def test_anthropic_prompt_caching_middleware_unsupported_model() -> None:
middleware = AnthropicPromptCachingMiddleware(unsupported_model_behavior="ignore")

result = middleware.wrap_model_call(fake_request, mock_handler)
assert isinstance(result, AIMessage)
assert isinstance(result, ModelResponse)

with patch.dict("sys.modules", {"langchain_anthropic": {"ChatAnthropic": object()}}):
result = middleware.wrap_model_call(fake_request, mock_handler)
assert isinstance(result, AIMessage)
assert isinstance(result, ModelResponse)


# Tests for SummarizationMiddleware
@@ -1346,10 +1356,10 @@ def test_on_model_call() -> None:
def wrap_model_call(
self,
request: ModelRequest,
handler: Callable[[ModelRequest], AIMessage],
) -> AIMessage:
request.messages.append(HumanMessage("remember to be nice!"))
return handler(request)
handler: Callable[[ModelCall], ModelResponse],
) -> ModelResponse:
request.model_call.messages.append(HumanMessage("remember to be nice!"))
return handler(request.model_call)

agent = create_agent(
model=FakeToolCallingModel(),
@@ -1482,10 +1492,10 @@ def test_runtime_injected_into_middleware() -> None:
def wrap_model_call(
self,
request: ModelRequest,
handler: Callable[[ModelRequest], AIMessage],
) -> AIMessage:
handler: Callable[[ModelCall], ModelResponse],
) -> ModelResponse:
assert request.runtime is not None
return handler(request)
return handler(request.model_call)

def after_model(self, state: AgentState, runtime: Runtime) -> None:
assert runtime is not None
@@ -1581,25 +1591,28 @@ def test_planning_middleware_on_model_call(original_prompt, expected_prompt_pref

state: PlanningState = {"messages": [HumanMessage(content="Hello")]}

request = ModelRequest(
model_call = ModelCall(
model=model,
system_prompt=original_prompt,
messages=[HumanMessage(content="Hello")],
tool_choice=None,
tools=[],
response_format=None,
state=state,
runtime=cast(Runtime, object()),
model_settings={},
)
request = ModelRequest(
model_call=model_call,
state=state,
runtime=cast(Runtime, object()),
)

def mock_handler(req: ModelRequest) -> AIMessage:
return AIMessage(content="mock response")
def mock_handler(req: ModelCall) -> ModelResponse:
return ModelResponse(result=[AIMessage(content="mock response")])

# Call wrap_model_call to trigger the middleware logic
middleware.wrap_model_call(request, mock_handler)
# Check that the request was modified in place
assert request.system_prompt.startswith(expected_prompt_prefix)
assert request.model_call.system_prompt.startswith(expected_prompt_prefix)
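The same direct-drive pattern, condensed: build the two-object request, call the hook with a stub handler, then assert against the mutated nested `ModelCall` rather than the `ModelRequest` wrapper. A sketch using the names from the test above:

```python
middleware.wrap_model_call(request, mock_handler)
# the middleware mutates request.model_call in place
assert request.model_call.system_prompt.startswith(expected_prompt_prefix)
```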
@pytest.mark.parametrize(
@@ -1725,13 +1738,13 @@ def test_planning_middleware_custom_system_prompt() -> None:

state: PlanningState = {"messages": [HumanMessage(content="Hello")]}

def mock_handler(req: ModelRequest) -> AIMessage:
return AIMessage(content="mock response")
def mock_handler(req: ModelCall) -> ModelResponse:
return ModelResponse(result=[AIMessage(content="mock response")])

# Call wrap_model_call to trigger the middleware logic
middleware.wrap_model_call(request, mock_handler)
# Check that the request was modified in place
assert request.system_prompt == f"Original prompt\n\n{custom_system_prompt}"
assert request.model_call.system_prompt == f"Original prompt\n\n{custom_system_prompt}"


def test_planning_middleware_custom_tool_description() -> None:
@@ -1757,25 +1770,28 @@ def test_planning_middleware_custom_system_prompt_and_tool_description() -> None
model = FakeToolCallingModel()
state: PlanningState = {"messages": [HumanMessage(content="Hello")]}

request = ModelRequest(
model_call = ModelCall(
model=model,
system_prompt=None,
messages=[HumanMessage(content="Hello")],
tool_choice=None,
tools=[],
response_format=None,
state=state,
runtime=cast(Runtime, object()),
model_settings={},
)
request = ModelRequest(
model_call=model_call,
state=state,
runtime=cast(Runtime, object()),
)

def mock_handler(req: ModelRequest) -> AIMessage:
return AIMessage(content="mock response")
def mock_handler(req: ModelCall) -> ModelResponse:
return ModelResponse(result=[AIMessage(content="mock response")])

# Call wrap_model_call to trigger the middleware logic
middleware.wrap_model_call(request, mock_handler)
# Check that the request was modified in place
assert request.system_prompt == custom_system_prompt
assert request.model_call.system_prompt == custom_system_prompt

# Verify tool description
assert len(middleware.tools) == 1
@@ -2047,11 +2063,11 @@ async def test_create_agent_async_invoke() -> None:
async def awrap_model_call(
self,
request: ModelRequest,
handler: Callable[[ModelRequest], Awaitable[AIMessage]],
) -> AIMessage:
handler: Callable[[ModelCall], Awaitable[ModelResponse]],
) -> ModelResponse:
calls.append("AsyncMiddleware.awrap_model_call")
request.messages.append(HumanMessage("async middleware message"))
return await handler(request)
request.model_call.messages.append(HumanMessage("async middleware message"))
return await handler(request.model_call)

async def aafter_model(self, state, runtime) -> None:
calls.append("AsyncMiddleware.aafter_model")
@@ -2108,10 +2124,10 @@ async def test_create_agent_async_invoke_multiple_middleware() -> None:
async def awrap_model_call(
self,
request: ModelRequest,
handler: Callable[[ModelRequest], Awaitable[AIMessage]],
) -> AIMessage:
handler: Callable[[ModelCall], Awaitable[ModelResponse]],
) -> ModelResponse:
calls.append("AsyncMiddlewareOne.awrap_model_call")
return await handler(request)
return await handler(request.model_call)

async def aafter_model(self, state, runtime) -> None:
calls.append("AsyncMiddlewareOne.aafter_model")
@@ -2123,10 +2139,10 @@ async def test_create_agent_async_invoke_multiple_middleware() -> None:
async def awrap_model_call(
self,
request: ModelRequest,
handler: Callable[[ModelRequest], Awaitable[AIMessage]],
) -> AIMessage:
handler: Callable[[ModelCall], Awaitable[ModelResponse]],
) -> ModelResponse:
calls.append("AsyncMiddlewareTwo.awrap_model_call")
return await handler(request)
return await handler(request.model_call)

async def aafter_model(self, state, runtime) -> None:
calls.append("AsyncMiddlewareTwo.aafter_model")
@@ -2196,10 +2212,10 @@ async def test_create_agent_mixed_sync_async_middleware() -> None:
def wrap_model_call(
self,
request: ModelRequest,
handler: Callable[[ModelRequest], AIMessage],
) -> AIMessage:
handler: Callable[[ModelCall], ModelResponse],
) -> ModelResponse:
calls.append("SyncMiddleware.wrap_model_call")
return handler(request)
return handler(request.model_call)

def after_model(self, state, runtime) -> None:
calls.append("SyncMiddleware.after_model")
@@ -2211,10 +2227,10 @@ async def test_create_agent_mixed_sync_async_middleware() -> None:
async def awrap_model_call(
self,
request: ModelRequest,
handler: Callable[[ModelRequest], Awaitable[AIMessage]],
) -> AIMessage:
handler: Callable[[ModelCall], Awaitable[ModelResponse]],
) -> ModelResponse:
calls.append("AsyncMiddleware.awrap_model_call")
return await handler(request)
return await handler(request.model_call)

async def aafter_model(self, state, runtime) -> None:
calls.append("AsyncMiddleware.aafter_model")
@@ -2267,11 +2283,11 @@ def test_wrap_model_call_hook() -> None:

def wrap_model_call(self, request, handler):
try:
return handler(request)
return handler(request.model_call)
except Exception:
# Retry on error
self.retry_count += 1
return handler(request)
return handler(request.model_call)

failing_model = FailingModel()
retry_middleware = RetryMiddleware()
@@ -2311,7 +2327,7 @@ def test_wrap_model_call_retry_count() -> None:
for attempt in range(max_retries):
self.attempts.append(attempt + 1)
try:
return handler(request)
return handler(request.model_call)
except Exception as e:
last_exception = e
if attempt < max_retries - 1:
@@ -2348,7 +2364,7 @@ def test_wrap_model_call_no_retry() -> None:

class NoRetryMiddleware(AgentMiddleware):
def wrap_model_call(self, request, handler):
return handler(request)
return handler(request.model_call)

agent = create_agent(model=FailingModel(), middleware=[NoRetryMiddleware()])

@@ -2469,7 +2485,7 @@ def test_wrap_model_call_max_attempts() -> None:
for attempt in range(self.max_retries):
self.attempt_count += 1
try:
return handler(request)
return handler(request.model_call)
except Exception as e:
last_exception = e
# Continue to retry
@@ -2520,11 +2536,11 @@ async def test_wrap_model_call_async() -> None:

async def awrap_model_call(self, request, handler):
try:
return await handler(request)
return await handler(request.model_call)
except Exception:
# Retry on error
self.retry_count += 1
return await handler(request)
return await handler(request.model_call)

failing_model = AsyncFailingModel()
retry_middleware = AsyncRetryMiddleware()
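An async sketch of the same retry shape, assuming `awrap_model_call` mirrors the sync hook with an awaitable handler over `ModelCall`:

```python
class AsyncRetryOnce(AgentMiddleware):
    async def awrap_model_call(self, request, handler):
        try:
            return await handler(request.model_call)
        except Exception:
            # one retry, mirroring the sync RetryMiddleware above
            return await handler(request.model_call)
```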
@@ -2559,10 +2575,12 @@ def test_wrap_model_call_rewrite_response() -> None:
"""Middleware that rewrites the response."""

def wrap_model_call(self, request, handler):
result = handler(request)
result = handler(request.model_call)

# Rewrite the response
return AIMessage(content=f"REWRITTEN: {result.content}")
return ModelResponse(
result=[AIMessage(content=f"REWRITTEN: {result.result[0].content}")]
)

model = SimpleModel()
middleware = ResponseRewriteMiddleware()
@@ -2593,10 +2611,12 @@ def test_wrap_model_call_convert_error_to_response() -> None:

def wrap_model_call(self, request, handler):
try:
return handler(request)
return handler(request.model_call)
except Exception as e:
# Convert error to success response
return AIMessage(content=f"Error occurred: {e}. Using fallback response.")
return ModelResponse(
result=[AIMessage(content=f"Error occurred: {e}. Using fallback response.")]
)

model = AlwaysFailingModel()
middleware = ErrorToResponseMiddleware()
@@ -2619,7 +2639,7 @@ def test_create_agent_sync_invoke_with_only_async_middleware_raises_error() -> N
self,
request: ModelRequest,
handler: Callable[[ModelRequest], Awaitable[AIMessage]],
) -> AIMessage:
) -> ModelResponse:
return await handler(request)

agent = create_agent(
@@ -2649,16 +2669,16 @@ def test_create_agent_sync_invoke_with_mixed_middleware() -> None:
def wrap_model_call(
self,
request: ModelRequest,
handler: Callable[[ModelRequest], AIMessage],
) -> AIMessage:
handler: Callable[[ModelCall], ModelResponse],
) -> ModelResponse:
calls.append("MixedMiddleware.wrap_model_call")
return handler(request)
return handler(request.model_call)

async def awrap_model_call(
self,
request: ModelRequest,
handler: Callable[[ModelRequest], Awaitable[AIMessage]],
) -> AIMessage:
) -> ModelResponse:
calls.append("MixedMiddleware.awrap_model_call")
return await handler(request)


@@ -13,7 +13,9 @@ from langgraph.types import Command
from langchain.agents.middleware.types import (
AgentMiddleware,
AgentState,
ModelCall,
ModelRequest,
ModelResponse,
before_model,
after_model,
dynamic_prompt,
@@ -89,8 +91,8 @@ def test_on_model_call_decorator() -> None:

@wrap_model_call(state_schema=CustomState, tools=[test_tool], name="CustomOnModelCall")
def custom_on_model_call(request, handler):
request.system_prompt = "Modified"
return handler(request)
request.model_call.system_prompt = "Modified"
return handler(request.model_call)

# Verify all options were applied
assert isinstance(custom_on_model_call, AgentMiddleware)
@@ -99,22 +101,27 @@ def test_on_model_call_decorator() -> None:
assert custom_on_model_call.__class__.__name__ == "CustomOnModelCall"

# Verify it works
original_request = ModelRequest(
model_call = ModelCall(
model="test-model",
system_prompt="Original",
messages=[HumanMessage("Hello")],
tool_choice=None,
tools=[],
response_format=None,
)
original_request = ModelRequest(
model_call=model_call,
state={"messages": [HumanMessage("Hello")]},
runtime=None,
)

def mock_handler(req):
return AIMessage(content=f"Handled with prompt: {req.system_prompt}")
return ModelResponse(
result=[AIMessage(content=f"Handled with prompt: {req.system_prompt}")]
)

result = custom_on_model_call.wrap_model_call(original_request, mock_handler)
assert result.content == "Handled with prompt: Modified"
assert result.result[0].content == "Handled with prompt: Modified"


def test_all_decorators_integration() -> None:
@@ -129,7 +136,7 @@ def test_all_decorators_integration() -> None:
@wrap_model_call
def track_on_call(request, handler):
call_order.append("on_call")
return handler(request)
return handler(request.model_call)

@after_model
def track_after(state: AgentState, runtime: Runtime) -> None:
@@ -324,7 +331,7 @@ async def test_async_decorators_integration() -> None:
@wrap_model_call
async def track_async_on_call(request, handler):
call_order.append("async_on_call")
return await handler(request)
return await handler(request.model_call)

@after_model
async def track_async_after(state: AgentState, runtime: Runtime) -> None:
@@ -364,7 +371,7 @@ async def test_mixed_sync_async_decorators_integration() -> None:
@wrap_model_call
async def track_async_on_call(request, handler):
call_order.append("async_on_call")
return await handler(request)
return await handler(request.model_call)

@after_model
async def track_async_after(state: AgentState, runtime: Runtime) -> None:
@@ -581,22 +588,25 @@ def test_dynamic_prompt_decorator() -> None:
assert my_prompt.__class__.__name__ == "my_prompt"

# Verify it modifies the request correctly
original_request = ModelRequest(
model_call = ModelCall(
model="test-model",
system_prompt="Original",
messages=[HumanMessage("Hello")],
tool_choice=None,
tools=[],
response_format=None,
)
original_request = ModelRequest(
model_call=model_call,
state={"messages": [HumanMessage("Hello")]},
runtime=None,
)

def mock_handler(req):
return AIMessage(content=req.system_prompt)
return ModelResponse(result=[AIMessage(content=req.system_prompt)])

result = my_prompt.wrap_model_call(original_request, mock_handler)
assert result.content == "Dynamic test prompt"
assert result.result[0].content == "Dynamic test prompt"
def test_dynamic_prompt_uses_state() -> None:
|
||||
@@ -608,22 +618,25 @@ def test_dynamic_prompt_uses_state() -> None:
        return f"Prompt with {msg_count} messages"

    # Verify it uses state correctly
    original_request = ModelRequest(
    model_call = ModelCall(
        model="test-model",
        system_prompt="Original",
        messages=[HumanMessage("Hello")],
        tool_choice=None,
        tools=[],
        response_format=None,
    )
    original_request = ModelRequest(
        model_call=model_call,
        state={"messages": [HumanMessage("Hello"), HumanMessage("World")]},
        runtime=None,
    )

    def mock_handler(req):
        return AIMessage(content=req.system_prompt)
        return ModelResponse(result=[AIMessage(content=req.system_prompt)])

    result = custom_prompt.wrap_model_call(original_request, mock_handler)
    assert result.content == "Prompt with 2 messages"
    assert result.result[0].content == "Prompt with 2 messages"


def test_dynamic_prompt_integration() -> None:
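The dynamic_prompt path exercised by the two tests above, in brief (a sketch; `greeting_prompt` is invented, and the callback signature is assumed from how the tests read state):

    from langchain.agents.middleware.types import dynamic_prompt

    @dynamic_prompt
    def greeting_prompt(request):
        # Returns the system prompt string; the resulting middleware applies it
        # to the nested ModelCall before the handler runs.
        return f"Prompt with {len(request.state['messages'])} messages"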
@@ -3,7 +3,13 @@
import pytest
from collections.abc import Callable

from langchain.agents.middleware.types import AgentMiddleware, AgentState, ModelRequest
from langchain.agents.middleware.types import (
    AgentMiddleware,
    AgentState,
    ModelCall,
    ModelRequest,
    ModelResponse,
)
from langchain.agents.factory import create_agent
from langchain.tools import ToolNode
from langchain_core.messages import AIMessage, HumanMessage, ToolMessage
@@ -30,10 +36,10 @@ def test_model_request_tools_are_base_tools() -> None:
        def wrap_model_call(
            self,
            request: ModelRequest,
            handler: Callable[[ModelRequest], AIMessage],
        ) -> AIMessage:
            handler: Callable[[ModelCall], ModelResponse],
        ) -> ModelResponse:
            captured_requests.append(request)
            return handler(request)
            return handler(request.model_call)

    agent = create_agent(
        model=FakeToolCallingModel(),
@@ -49,9 +55,9 @@ def test_model_request_tools_are_base_tools() -> None:

    # Check that tools in the request are BaseTool objects
    request = captured_requests[0]
    assert isinstance(request.tools, list)
    assert len(request.tools) == 2
    assert {t.name for t in request.tools} == {"search_tool", "calculator"}
    assert isinstance(request.model_call.tools, list)
    assert len(request.model_call.tools) == 2
    assert {t.name for t in request.model_call.tools} == {"search_tool", "calculator"}


def test_middleware_can_modify_tools() -> None:
@@ -76,11 +82,13 @@ def test_middleware_can_modify_tools() -> None:
        def wrap_model_call(
            self,
            request: ModelRequest,
            handler: Callable[[ModelRequest], AIMessage],
        ) -> AIMessage:
            handler: Callable[[ModelCall], ModelResponse],
        ) -> ModelResponse:
            # Only allow tool_a and tool_b
            request.tools = [t for t in request.tools if t.name in ["tool_a", "tool_b"]]
            return handler(request)
            request.model_call.tools = [
                t for t in request.model_call.tools if t.name in ["tool_a", "tool_b"]
            ]
            return handler(request.model_call)

    # Model will try to call tool_a
    model = FakeToolCallingModel(
@@ -121,11 +129,11 @@ def test_unknown_tool_raises_error() -> None:
        def wrap_model_call(
            self,
            request: ModelRequest,
            handler: Callable[[ModelRequest], AIMessage],
        ) -> AIMessage:
            handler: Callable[[ModelCall], ModelResponse],
        ) -> ModelResponse:
            # Add an unknown tool
            request.tools = request.tools + [unknown_tool]
            return handler(request)
            request.model_call.tools = request.model_call.tools + [unknown_tool]
            return handler(request.model_call)

    agent = create_agent(
        model=FakeToolCallingModel(),
@@ -160,12 +168,14 @@ def test_middleware_can_add_and_remove_tools() -> None:
        def wrap_model_call(
            self,
            request: ModelRequest,
            handler: Callable[[ModelRequest], AIMessage],
        ) -> AIMessage:
            handler: Callable[[ModelCall], ModelResponse],
        ) -> ModelResponse:
            # Remove admin_tool if not admin
            if not request.state.get("is_admin", False):
                request.tools = [t for t in request.tools if t.name != "admin_tool"]
            return handler(request)
                request.model_call.tools = [
                    t for t in request.model_call.tools if t.name != "admin_tool"
                ]
            return handler(request.model_call)

    model = FakeToolCallingModel()
@@ -198,11 +208,11 @@ def test_empty_tools_list_is_valid() -> None:
        def wrap_model_call(
            self,
            request: ModelRequest,
            handler: Callable[[ModelRequest], AIMessage],
        ) -> AIMessage:
            handler: Callable[[ModelCall], ModelResponse],
        ) -> ModelResponse:
            # Remove all tools
            request.tools = []
            return handler(request)
            request.model_call.tools = []
            return handler(request.model_call)

    model = FakeToolCallingModel()
@@ -241,25 +251,25 @@ def test_tools_preserved_across_multiple_middleware() -> None:
        def wrap_model_call(
            self,
            request: ModelRequest,
            handler: Callable[[ModelRequest], AIMessage],
        ) -> AIMessage:
            modification_order.append([t.name for t in request.tools])
            handler: Callable[[ModelCall], ModelResponse],
        ) -> ModelResponse:
            modification_order.append([t.name for t in request.model_call.tools])
            # Remove tool_c
            request.tools = [t for t in request.tools if t.name != "tool_c"]
            return handler(request)
            request.model_call.tools = [t for t in request.model_call.tools if t.name != "tool_c"]
            return handler(request.model_call)

    class SecondMiddleware(AgentMiddleware):
        def wrap_model_call(
            self,
            request: ModelRequest,
            handler: Callable[[ModelRequest], AIMessage],
        ) -> AIMessage:
            modification_order.append([t.name for t in request.tools])
            handler: Callable[[ModelCall], ModelResponse],
        ) -> ModelResponse:
            modification_order.append([t.name for t in request.model_call.tools])
            # Should not see tool_c here
            assert all(t.name != "tool_c" for t in request.tools)
            assert all(t.name != "tool_c" for t in request.model_call.tools)
            # Remove tool_b
            request.tools = [t for t in request.tools if t.name != "tool_b"]
            return handler(request)
            request.model_call.tools = [t for t in request.model_call.tools if t.name != "tool_b"]
            return handler(request.model_call)

    agent = create_agent(
        model=FakeToolCallingModel(),
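For class-based middleware, the same migration in miniature (a sketch built from the signatures shown above; `BlockAdminTool` is an invented example class, not part of this change):

    from collections.abc import Callable
    from langchain.agents.middleware.types import (
        AgentMiddleware,
        ModelCall,
        ModelRequest,
        ModelResponse,
    )

    class BlockAdminTool(AgentMiddleware):  # hypothetical example
        def wrap_model_call(
            self,
            request: ModelRequest,
            handler: Callable[[ModelCall], ModelResponse],
        ) -> ModelResponse:
            # Tools are now read and rewritten on the nested ModelCall ...
            request.model_call.tools = [
                t for t in request.model_call.tools if t.name != "admin_tool"
            ]
            # ... and the handler is invoked with the ModelCall, not the request.
            return handler(request.model_call)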
@@ -698,7 +698,12 @@ class TestDynamicModelWithResponseFormat:
        selected based on the final model's capabilities.
        """
        from unittest.mock import patch
        from langchain.agents.middleware.types import AgentMiddleware, ModelRequest
        from langchain.agents.middleware.types import (
            AgentMiddleware,
            ModelCall,
            ModelRequest,
            ModelResponse,
        )
        from langchain_core.language_models.fake_chat_models import GenericFakeChatModel

        # Custom model that we'll use to test whether the tool strategy is applied
@@ -730,11 +735,11 @@ class TestDynamicModelWithResponseFormat:
            def wrap_model_call(
                self,
                request: ModelRequest,
                handler: Callable[[ModelRequest], CoreAIMessage],
            ) -> CoreAIMessage:
                handler: Callable[[ModelCall], ModelResponse],
            ) -> ModelResponse:
                # Replace the model with our custom test model
                request.model = model
                return handler(request)
                request.model_call.model = model
                return handler(request.model_call)

        # Track which model is checked for provider strategy support
        calls = []