mirror of
https://github.com/hwchase17/langchain.git
synced 2026-06-09 18:50:33 +00:00
revert: Support for SystemMessage in create_agent (#33889)
Reverts langchain-ai/langchain#33640 Introduces lint errors into langchain-anthropic Should incorporate into 1.1 instead of patch release.
This commit is contained in:
@@ -516,7 +516,7 @@ def create_agent( # noqa: PLR0915
|
|||||||
model: str | BaseChatModel,
|
model: str | BaseChatModel,
|
||||||
tools: Sequence[BaseTool | Callable | dict[str, Any]] | None = None,
|
tools: Sequence[BaseTool | Callable | dict[str, Any]] | None = None,
|
||||||
*,
|
*,
|
||||||
system_prompt: str | SystemMessage | None = None,
|
system_prompt: str | None = None,
|
||||||
middleware: Sequence[AgentMiddleware[StateT_co, ContextT]] = (),
|
middleware: Sequence[AgentMiddleware[StateT_co, ContextT]] = (),
|
||||||
response_format: ResponseFormat[ResponseT] | type[ResponseT] | None = None,
|
response_format: ResponseFormat[ResponseT] | type[ResponseT] | None = None,
|
||||||
state_schema: type[AgentState[ResponseT]] | None = None,
|
state_schema: type[AgentState[ResponseT]] | None = None,
|
||||||
@@ -548,9 +548,11 @@ def create_agent( # noqa: PLR0915
|
|||||||
|
|
||||||
If `None` or an empty list, the agent will consist of a model node without a
|
If `None` or an empty list, the agent will consist of a model node without a
|
||||||
tool calling loop.
|
tool calling loop.
|
||||||
system_prompt: An optional system prompt for the LLM or
|
system_prompt: An optional system prompt for the LLM.
|
||||||
can already be a [`SystemMessage`][langchain.messages.SystemMessage] object.
|
|
||||||
|
|
||||||
|
Prompts are converted to a
|
||||||
|
[`SystemMessage`][langchain.messages.SystemMessage] and added to the
|
||||||
|
beginning of the message list.
|
||||||
middleware: A sequence of middleware instances to apply to the agent.
|
middleware: A sequence of middleware instances to apply to the agent.
|
||||||
|
|
||||||
Middleware can intercept and modify agent behavior at various stages. See
|
Middleware can intercept and modify agent behavior at various stages. See
|
||||||
@@ -1038,10 +1040,8 @@ def create_agent( # noqa: PLR0915
|
|||||||
# Get the bound model (with auto-detection if needed)
|
# Get the bound model (with auto-detection if needed)
|
||||||
model_, effective_response_format = _get_bound_model(request)
|
model_, effective_response_format = _get_bound_model(request)
|
||||||
messages = request.messages
|
messages = request.messages
|
||||||
if request.system_prompt and not isinstance(request.system_prompt, SystemMessage):
|
if request.system_prompt:
|
||||||
messages = [SystemMessage(content=request.system_prompt), *messages]
|
messages = [SystemMessage(request.system_prompt), *messages]
|
||||||
elif request.system_prompt and isinstance(request.system_prompt, SystemMessage):
|
|
||||||
messages = [request.system_prompt, *messages]
|
|
||||||
|
|
||||||
output = model_.invoke(messages)
|
output = model_.invoke(messages)
|
||||||
|
|
||||||
@@ -1093,10 +1093,8 @@ def create_agent( # noqa: PLR0915
|
|||||||
# Get the bound model (with auto-detection if needed)
|
# Get the bound model (with auto-detection if needed)
|
||||||
model_, effective_response_format = _get_bound_model(request)
|
model_, effective_response_format = _get_bound_model(request)
|
||||||
messages = request.messages
|
messages = request.messages
|
||||||
if request.system_prompt and not isinstance(request.system_prompt, SystemMessage):
|
if request.system_prompt:
|
||||||
messages = [SystemMessage(content=request.system_prompt), *messages]
|
messages = [SystemMessage(request.system_prompt), *messages]
|
||||||
elif request.system_prompt and isinstance(request.system_prompt, SystemMessage):
|
|
||||||
messages = [request.system_prompt, *messages]
|
|
||||||
|
|
||||||
output = await model_.ainvoke(messages)
|
output = await model_.ainvoke(messages)
|
||||||
|
|
||||||
|
|||||||
@@ -225,11 +225,9 @@ class ContextEditingMiddleware(AgentMiddleware):
|
|||||||
def count_tokens(messages: Sequence[BaseMessage]) -> int:
|
def count_tokens(messages: Sequence[BaseMessage]) -> int:
|
||||||
return count_tokens_approximately(messages)
|
return count_tokens_approximately(messages)
|
||||||
else:
|
else:
|
||||||
system_msg = []
|
system_msg = (
|
||||||
if request.system_prompt and not isinstance(request.system_prompt, SystemMessage):
|
[SystemMessage(content=request.system_prompt)] if request.system_prompt else []
|
||||||
system_msg = [SystemMessage(content=request.system_prompt)]
|
)
|
||||||
elif request.system_prompt and isinstance(request.system_prompt, SystemMessage):
|
|
||||||
system_msg = [request.system_prompt]
|
|
||||||
|
|
||||||
def count_tokens(messages: Sequence[BaseMessage]) -> int:
|
def count_tokens(messages: Sequence[BaseMessage]) -> int:
|
||||||
return request.model.get_num_tokens_from_messages(
|
return request.model.get_num_tokens_from_messages(
|
||||||
@@ -255,12 +253,9 @@ class ContextEditingMiddleware(AgentMiddleware):
|
|||||||
def count_tokens(messages: Sequence[BaseMessage]) -> int:
|
def count_tokens(messages: Sequence[BaseMessage]) -> int:
|
||||||
return count_tokens_approximately(messages)
|
return count_tokens_approximately(messages)
|
||||||
else:
|
else:
|
||||||
system_msg = []
|
system_msg = (
|
||||||
|
[SystemMessage(content=request.system_prompt)] if request.system_prompt else []
|
||||||
if request.system_prompt and not isinstance(request.system_prompt, SystemMessage):
|
)
|
||||||
system_msg = [SystemMessage(content=request.system_prompt)]
|
|
||||||
elif request.system_prompt and isinstance(request.system_prompt, SystemMessage):
|
|
||||||
system_msg = [request.system_prompt]
|
|
||||||
|
|
||||||
def count_tokens(messages: Sequence[BaseMessage]) -> int:
|
def count_tokens(messages: Sequence[BaseMessage]) -> int:
|
||||||
return request.model.get_num_tokens_from_messages(
|
return request.model.get_num_tokens_from_messages(
|
||||||
|
|||||||
@@ -8,7 +8,7 @@ from typing import TYPE_CHECKING, Annotated, Literal
|
|||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from collections.abc import Awaitable, Callable
|
from collections.abc import Awaitable, Callable
|
||||||
|
|
||||||
from langchain_core.messages import SystemMessage, ToolMessage
|
from langchain_core.messages import ToolMessage
|
||||||
from langchain_core.tools import tool
|
from langchain_core.tools import tool
|
||||||
from langgraph.types import Command
|
from langgraph.types import Command
|
||||||
from typing_extensions import NotRequired, TypedDict
|
from typing_extensions import NotRequired, TypedDict
|
||||||
@@ -199,22 +199,11 @@ class TodoListMiddleware(AgentMiddleware):
|
|||||||
handler: Callable[[ModelRequest], ModelResponse],
|
handler: Callable[[ModelRequest], ModelResponse],
|
||||||
) -> ModelCallResult:
|
) -> ModelCallResult:
|
||||||
"""Update the system prompt to include the todo system prompt."""
|
"""Update the system prompt to include the todo system prompt."""
|
||||||
if request.system_prompt is None:
|
request.system_prompt = (
|
||||||
request.system_prompt = self.system_prompt
|
request.system_prompt + "\n\n" + self.system_prompt
|
||||||
elif isinstance(request.system_prompt, str):
|
if request.system_prompt
|
||||||
request.system_prompt = request.system_prompt + "\n\n" + self.system_prompt
|
else self.system_prompt
|
||||||
elif isinstance(request.system_prompt, SystemMessage) and isinstance(
|
)
|
||||||
request.system_prompt.content, str
|
|
||||||
):
|
|
||||||
request.system_prompt = SystemMessage(
|
|
||||||
content=request.system_prompt.content + self.system_prompt
|
|
||||||
)
|
|
||||||
elif isinstance(request.system_prompt, SystemMessage) and isinstance(
|
|
||||||
request.system_prompt.content, list
|
|
||||||
):
|
|
||||||
request.system_prompt = SystemMessage(
|
|
||||||
content=[*request.system_prompt.content, self.system_prompt]
|
|
||||||
)
|
|
||||||
return handler(request)
|
return handler(request)
|
||||||
|
|
||||||
async def awrap_model_call(
|
async def awrap_model_call(
|
||||||
@@ -223,20 +212,9 @@ class TodoListMiddleware(AgentMiddleware):
|
|||||||
handler: Callable[[ModelRequest], Awaitable[ModelResponse]],
|
handler: Callable[[ModelRequest], Awaitable[ModelResponse]],
|
||||||
) -> ModelCallResult:
|
) -> ModelCallResult:
|
||||||
"""Update the system prompt to include the todo system prompt (async version)."""
|
"""Update the system prompt to include the todo system prompt (async version)."""
|
||||||
if request.system_prompt is None:
|
request.system_prompt = (
|
||||||
request.system_prompt = self.system_prompt
|
request.system_prompt + "\n\n" + self.system_prompt
|
||||||
elif isinstance(request.system_prompt, str):
|
if request.system_prompt
|
||||||
request.system_prompt = request.system_prompt + "\n\n" + self.system_prompt
|
else self.system_prompt
|
||||||
elif isinstance(request.system_prompt, SystemMessage) and isinstance(
|
)
|
||||||
request.system_prompt.content, str
|
|
||||||
):
|
|
||||||
request.system_prompt = SystemMessage(
|
|
||||||
content=request.system_prompt.content + self.system_prompt
|
|
||||||
)
|
|
||||||
elif isinstance(request.system_prompt, SystemMessage) and isinstance(
|
|
||||||
request.system_prompt.content, list
|
|
||||||
):
|
|
||||||
request.system_prompt = SystemMessage(
|
|
||||||
content=[*request.system_prompt.content, self.system_prompt]
|
|
||||||
)
|
|
||||||
return await handler(request)
|
return await handler(request)
|
||||||
|
|||||||
@@ -26,7 +26,6 @@ from langchain_core.messages import ( # noqa: TC002
|
|||||||
AIMessage,
|
AIMessage,
|
||||||
AnyMessage,
|
AnyMessage,
|
||||||
BaseMessage,
|
BaseMessage,
|
||||||
SystemMessage,
|
|
||||||
ToolMessage,
|
ToolMessage,
|
||||||
)
|
)
|
||||||
from langgraph.channels.ephemeral_value import EphemeralValue
|
from langgraph.channels.ephemeral_value import EphemeralValue
|
||||||
@@ -86,7 +85,7 @@ class ModelRequest:
|
|||||||
"""Model request information for the agent."""
|
"""Model request information for the agent."""
|
||||||
|
|
||||||
model: BaseChatModel
|
model: BaseChatModel
|
||||||
system_prompt: str | SystemMessage | None
|
system_prompt: str | None
|
||||||
messages: list[AnyMessage] # excluding system prompt
|
messages: list[AnyMessage] # excluding system prompt
|
||||||
tool_choice: Any | None
|
tool_choice: Any | None
|
||||||
tools: list[BaseTool | dict]
|
tools: list[BaseTool | dict]
|
||||||
@@ -104,7 +103,7 @@ class ModelRequest:
|
|||||||
Args:
|
Args:
|
||||||
**overrides: Keyword arguments for attributes to override. Supported keys:
|
**overrides: Keyword arguments for attributes to override. Supported keys:
|
||||||
- model: BaseChatModel instance
|
- model: BaseChatModel instance
|
||||||
- system_prompt: Optional system prompt string or SystemMessage object
|
- system_prompt: Optional system prompt string
|
||||||
- messages: List of messages
|
- messages: List of messages
|
||||||
- tool_choice: Tool choice configuration
|
- tool_choice: Tool choice configuration
|
||||||
- tools: List of available tools
|
- tools: List of available tools
|
||||||
@@ -1257,7 +1256,7 @@ def dynamic_prompt(
|
|||||||
request: ModelRequest,
|
request: ModelRequest,
|
||||||
handler: Callable[[ModelRequest], ModelResponse],
|
handler: Callable[[ModelRequest], ModelResponse],
|
||||||
) -> ModelCallResult:
|
) -> ModelCallResult:
|
||||||
prompt = cast("str | SystemMessage", func(request))
|
prompt = cast("str", func(request))
|
||||||
request.system_prompt = prompt
|
request.system_prompt = prompt
|
||||||
return handler(request)
|
return handler(request)
|
||||||
|
|
||||||
@@ -1267,7 +1266,7 @@ def dynamic_prompt(
|
|||||||
handler: Callable[[ModelRequest], Awaitable[ModelResponse]],
|
handler: Callable[[ModelRequest], Awaitable[ModelResponse]],
|
||||||
) -> ModelCallResult:
|
) -> ModelCallResult:
|
||||||
# Delegate to sync function
|
# Delegate to sync function
|
||||||
prompt = cast("str | SystemMessage", func(request))
|
prompt = cast("str", func(request))
|
||||||
request.system_prompt = prompt
|
request.system_prompt = prompt
|
||||||
return await handler(request)
|
return await handler(request)
|
||||||
|
|
||||||
|
|||||||
@@ -13,7 +13,6 @@ from langchain_core.language_models.fake_chat_models import FakeChatModel
|
|||||||
from langchain_core.messages import (
|
from langchain_core.messages import (
|
||||||
AIMessage,
|
AIMessage,
|
||||||
MessageLikeRepresentation,
|
MessageLikeRepresentation,
|
||||||
SystemMessage,
|
|
||||||
ToolMessage,
|
ToolMessage,
|
||||||
)
|
)
|
||||||
from langgraph.runtime import Runtime
|
from langgraph.runtime import Runtime
|
||||||
@@ -400,126 +399,3 @@ async def test_exclude_tools_prevents_clearing_async() -> None:
|
|||||||
|
|
||||||
assert isinstance(calc_tool, ToolMessage)
|
assert isinstance(calc_tool, ToolMessage)
|
||||||
assert calc_tool.content == "[cleared]"
|
assert calc_tool.content == "[cleared]"
|
||||||
|
|
||||||
|
|
||||||
# ==============================================================================
|
|
||||||
# SystemMessage Tests
|
|
||||||
# ==============================================================================
|
|
||||||
|
|
||||||
|
|
||||||
def test_handles_system_message_prompt() -> None:
|
|
||||||
"""Test that middleware handles SystemMessage as system_prompt correctly."""
|
|
||||||
tool_call_id = "call-1"
|
|
||||||
ai_message = AIMessage(
|
|
||||||
content="",
|
|
||||||
tool_calls=[{"id": tool_call_id, "name": "search", "args": {}}],
|
|
||||||
)
|
|
||||||
tool_message = ToolMessage(content="12345", tool_call_id=tool_call_id)
|
|
||||||
|
|
||||||
system_prompt = SystemMessage(content="You are a helpful assistant.")
|
|
||||||
state, request = _make_state_and_request([ai_message, tool_message], system_prompt=None)
|
|
||||||
# Manually set SystemMessage as system_prompt
|
|
||||||
request.system_prompt = system_prompt
|
|
||||||
|
|
||||||
middleware = ContextEditingMiddleware(
|
|
||||||
edits=[ClearToolUsesEdit(trigger=50)],
|
|
||||||
token_count_method="model",
|
|
||||||
)
|
|
||||||
|
|
||||||
def mock_handler(req: ModelRequest) -> AIMessage:
|
|
||||||
return AIMessage(content="mock response")
|
|
||||||
|
|
||||||
# Call wrap_model_call - should not fail with SystemMessage
|
|
||||||
middleware.wrap_model_call(request, mock_handler)
|
|
||||||
|
|
||||||
# Request should have processed without errors
|
|
||||||
assert request.system_prompt == system_prompt
|
|
||||||
assert isinstance(request.system_prompt, SystemMessage)
|
|
||||||
|
|
||||||
|
|
||||||
def test_does_not_double_wrap_system_message() -> None:
|
|
||||||
"""Test that middleware doesn't wrap SystemMessage in another SystemMessage."""
|
|
||||||
tool_call_id = "call-1"
|
|
||||||
ai_message = AIMessage(
|
|
||||||
content="",
|
|
||||||
tool_calls=[{"id": tool_call_id, "name": "search", "args": {}}],
|
|
||||||
)
|
|
||||||
tool_message = ToolMessage(content="x" * 100, tool_call_id=tool_call_id)
|
|
||||||
|
|
||||||
system_prompt = SystemMessage(content="Original system prompt")
|
|
||||||
state, request = _make_state_and_request([ai_message, tool_message], system_prompt=None)
|
|
||||||
request.system_prompt = system_prompt
|
|
||||||
|
|
||||||
middleware = ContextEditingMiddleware(
|
|
||||||
edits=[ClearToolUsesEdit(trigger=50)],
|
|
||||||
token_count_method="model",
|
|
||||||
)
|
|
||||||
|
|
||||||
def mock_handler(req: ModelRequest) -> AIMessage:
|
|
||||||
return AIMessage(content="mock response")
|
|
||||||
|
|
||||||
middleware.wrap_model_call(request, mock_handler)
|
|
||||||
|
|
||||||
# System prompt should still be the same SystemMessage, not wrapped
|
|
||||||
assert request.system_prompt == system_prompt
|
|
||||||
assert isinstance(request.system_prompt, SystemMessage)
|
|
||||||
assert request.system_prompt.content == "Original system prompt"
|
|
||||||
|
|
||||||
|
|
||||||
async def test_handles_system_message_prompt_async() -> None:
|
|
||||||
"""Test async version - middleware handles SystemMessage as system_prompt correctly."""
|
|
||||||
tool_call_id = "call-1"
|
|
||||||
ai_message = AIMessage(
|
|
||||||
content="",
|
|
||||||
tool_calls=[{"id": tool_call_id, "name": "search", "args": {}}],
|
|
||||||
)
|
|
||||||
tool_message = ToolMessage(content="12345", tool_call_id=tool_call_id)
|
|
||||||
|
|
||||||
system_prompt = SystemMessage(content="You are a helpful assistant.")
|
|
||||||
state, request = _make_state_and_request([ai_message, tool_message], system_prompt=None)
|
|
||||||
# Manually set SystemMessage as system_prompt
|
|
||||||
request.system_prompt = system_prompt
|
|
||||||
|
|
||||||
middleware = ContextEditingMiddleware(
|
|
||||||
edits=[ClearToolUsesEdit(trigger=50)],
|
|
||||||
token_count_method="model",
|
|
||||||
)
|
|
||||||
|
|
||||||
async def mock_handler(req: ModelRequest) -> AIMessage:
|
|
||||||
return AIMessage(content="mock response")
|
|
||||||
|
|
||||||
# Call awrap_model_call - should not fail with SystemMessage
|
|
||||||
await middleware.awrap_model_call(request, mock_handler)
|
|
||||||
|
|
||||||
# Request should have processed without errors
|
|
||||||
assert request.system_prompt == system_prompt
|
|
||||||
assert isinstance(request.system_prompt, SystemMessage)
|
|
||||||
|
|
||||||
|
|
||||||
async def test_does_not_double_wrap_system_message_async() -> None:
|
|
||||||
"""Test async version - middleware doesn't wrap SystemMessage in another SystemMessage."""
|
|
||||||
tool_call_id = "call-1"
|
|
||||||
ai_message = AIMessage(
|
|
||||||
content="",
|
|
||||||
tool_calls=[{"id": tool_call_id, "name": "search", "args": {}}],
|
|
||||||
)
|
|
||||||
tool_message = ToolMessage(content="x" * 100, tool_call_id=tool_call_id)
|
|
||||||
|
|
||||||
system_prompt = SystemMessage(content="Original system prompt")
|
|
||||||
state, request = _make_state_and_request([ai_message, tool_message], system_prompt=None)
|
|
||||||
request.system_prompt = system_prompt
|
|
||||||
|
|
||||||
middleware = ContextEditingMiddleware(
|
|
||||||
edits=[ClearToolUsesEdit(trigger=50)],
|
|
||||||
token_count_method="model",
|
|
||||||
)
|
|
||||||
|
|
||||||
async def mock_handler(req: ModelRequest) -> AIMessage:
|
|
||||||
return AIMessage(content="mock response")
|
|
||||||
|
|
||||||
await middleware.awrap_model_call(request, mock_handler)
|
|
||||||
|
|
||||||
# System prompt should still be the same SystemMessage, not wrapped
|
|
||||||
assert request.system_prompt == system_prompt
|
|
||||||
assert isinstance(request.system_prompt, SystemMessage)
|
|
||||||
assert request.system_prompt.content == "Original system prompt"
|
|
||||||
|
|||||||
@@ -5,7 +5,7 @@ from __future__ import annotations
|
|||||||
from typing import cast
|
from typing import cast
|
||||||
|
|
||||||
from langchain_core.language_models.fake_chat_models import GenericFakeChatModel
|
from langchain_core.language_models.fake_chat_models import GenericFakeChatModel
|
||||||
from langchain_core.messages import AIMessage, SystemMessage
|
from langchain_core.messages import AIMessage
|
||||||
|
|
||||||
from langchain.agents.middleware.todo import TodoListMiddleware
|
from langchain.agents.middleware.todo import TodoListMiddleware
|
||||||
from langchain.agents.middleware.types import ModelRequest, ModelResponse
|
from langchain.agents.middleware.types import ModelRequest, ModelResponse
|
||||||
@@ -170,131 +170,3 @@ async def test_handler_called_with_modified_request_async() -> None:
|
|||||||
assert received_prompt["value"] is not None
|
assert received_prompt["value"] is not None
|
||||||
assert "Original" in received_prompt["value"]
|
assert "Original" in received_prompt["value"]
|
||||||
assert "write_todos" in received_prompt["value"]
|
assert "write_todos" in received_prompt["value"]
|
||||||
|
|
||||||
|
|
||||||
# ==============================================================================
|
|
||||||
# SystemMessage Tests
|
|
||||||
# ==============================================================================
|
|
||||||
|
|
||||||
|
|
||||||
def test_appends_to_system_message_with_list_content() -> None:
|
|
||||||
"""Test that middleware appends to SystemMessage with list content."""
|
|
||||||
existing_prompt = SystemMessage(content=["You are helpful.", "Be concise."])
|
|
||||||
middleware = TodoListMiddleware()
|
|
||||||
|
|
||||||
model = GenericFakeChatModel(messages=iter([AIMessage(content="response")]))
|
|
||||||
request = ModelRequest(
|
|
||||||
model=model,
|
|
||||||
system_prompt=existing_prompt,
|
|
||||||
messages=[],
|
|
||||||
tool_choice=None,
|
|
||||||
tools=[],
|
|
||||||
response_format=None,
|
|
||||||
state=cast("AgentState", {}), # type: ignore[name-defined]
|
|
||||||
runtime=_fake_runtime(),
|
|
||||||
model_settings={},
|
|
||||||
)
|
|
||||||
|
|
||||||
def mock_handler(req: ModelRequest) -> ModelResponse:
|
|
||||||
return ModelResponse(result=[AIMessage(content="response")])
|
|
||||||
|
|
||||||
middleware.wrap_model_call(request, mock_handler)
|
|
||||||
|
|
||||||
# System prompt should be a SystemMessage with combined content
|
|
||||||
assert isinstance(request.system_prompt, SystemMessage)
|
|
||||||
assert isinstance(request.system_prompt.content, list)
|
|
||||||
assert len(request.system_prompt.content) == 3
|
|
||||||
assert request.system_prompt.content[0] == "You are helpful."
|
|
||||||
assert request.system_prompt.content[1] == "Be concise."
|
|
||||||
assert "write_todos" in request.system_prompt.content[2]
|
|
||||||
|
|
||||||
|
|
||||||
async def test_appends_to_system_message_with_string_content_async() -> None:
|
|
||||||
"""Test async version - middleware appends to SystemMessage with string content."""
|
|
||||||
existing_prompt = SystemMessage(content="You are a helpful assistant.")
|
|
||||||
middleware = TodoListMiddleware()
|
|
||||||
|
|
||||||
model = GenericFakeChatModel(messages=iter([AIMessage(content="response")]))
|
|
||||||
request = ModelRequest(
|
|
||||||
model=model,
|
|
||||||
system_prompt=existing_prompt,
|
|
||||||
messages=[],
|
|
||||||
tool_choice=None,
|
|
||||||
tools=[],
|
|
||||||
response_format=None,
|
|
||||||
state=cast("AgentState", {}), # type: ignore[name-defined]
|
|
||||||
runtime=_fake_runtime(),
|
|
||||||
model_settings={},
|
|
||||||
)
|
|
||||||
|
|
||||||
async def mock_handler(req: ModelRequest) -> ModelResponse:
|
|
||||||
return ModelResponse(result=[AIMessage(content="response")])
|
|
||||||
|
|
||||||
await middleware.awrap_model_call(request, mock_handler)
|
|
||||||
|
|
||||||
# System prompt should be a SystemMessage with combined content
|
|
||||||
assert isinstance(request.system_prompt, SystemMessage)
|
|
||||||
assert isinstance(request.system_prompt.content, str)
|
|
||||||
assert "You are a helpful assistant." in request.system_prompt.content
|
|
||||||
assert middleware.system_prompt in request.system_prompt.content
|
|
||||||
|
|
||||||
|
|
||||||
async def test_appends_to_system_message_with_list_content_async() -> None:
|
|
||||||
"""Test async version - middleware appends to SystemMessage with list content."""
|
|
||||||
existing_prompt = SystemMessage(content=["You are helpful.", "Be concise."])
|
|
||||||
middleware = TodoListMiddleware()
|
|
||||||
|
|
||||||
model = GenericFakeChatModel(messages=iter([AIMessage(content="response")]))
|
|
||||||
request = ModelRequest(
|
|
||||||
model=model,
|
|
||||||
system_prompt=existing_prompt,
|
|
||||||
messages=[],
|
|
||||||
tool_choice=None,
|
|
||||||
tools=[],
|
|
||||||
response_format=None,
|
|
||||||
state=cast("AgentState", {}), # type: ignore[name-defined]
|
|
||||||
runtime=_fake_runtime(),
|
|
||||||
model_settings={},
|
|
||||||
)
|
|
||||||
|
|
||||||
async def mock_handler(req: ModelRequest) -> ModelResponse:
|
|
||||||
return ModelResponse(result=[AIMessage(content="response")])
|
|
||||||
|
|
||||||
await middleware.awrap_model_call(request, mock_handler)
|
|
||||||
|
|
||||||
# System prompt should be a SystemMessage with combined content
|
|
||||||
assert isinstance(request.system_prompt, SystemMessage)
|
|
||||||
assert isinstance(request.system_prompt.content, list)
|
|
||||||
assert len(request.system_prompt.content) == 3
|
|
||||||
assert request.system_prompt.content[0] == "You are helpful."
|
|
||||||
assert request.system_prompt.content[1] == "Be concise."
|
|
||||||
assert "write_todos" in request.system_prompt.content[2]
|
|
||||||
|
|
||||||
|
|
||||||
def test_creates_system_message_when_prompt_is_system_message() -> None:
|
|
||||||
"""Test that middleware preserves SystemMessage type."""
|
|
||||||
existing_prompt = SystemMessage(content="Original instructions")
|
|
||||||
middleware = TodoListMiddleware(system_prompt="Todo instructions")
|
|
||||||
|
|
||||||
model = GenericFakeChatModel(messages=iter([AIMessage(content="response")]))
|
|
||||||
request = ModelRequest(
|
|
||||||
model=model,
|
|
||||||
system_prompt=existing_prompt,
|
|
||||||
messages=[],
|
|
||||||
tool_choice=None,
|
|
||||||
tools=[],
|
|
||||||
response_format=None,
|
|
||||||
state=cast("AgentState", {}), # type: ignore[name-defined]
|
|
||||||
runtime=_fake_runtime(),
|
|
||||||
model_settings={},
|
|
||||||
)
|
|
||||||
|
|
||||||
def mock_handler(req: ModelRequest) -> ModelResponse:
|
|
||||||
return ModelResponse(result=[AIMessage(content="response")])
|
|
||||||
|
|
||||||
result = middleware.wrap_model_call(request, mock_handler)
|
|
||||||
|
|
||||||
# Result should be returned from handler
|
|
||||||
assert isinstance(result, ModelResponse)
|
|
||||||
# System prompt should still be a SystemMessage
|
|
||||||
assert isinstance(request.system_prompt, SystemMessage)
|
|
||||||
|
|||||||
Reference in New Issue
Block a user