diff --git a/libs/langchain_v1/langchain/agents/factory.py b/libs/langchain_v1/langchain/agents/factory.py index 9bc18ea943b..c8880b65cd7 100644 --- a/libs/langchain_v1/langchain/agents/factory.py +++ b/libs/langchain_v1/langchain/agents/factory.py @@ -516,7 +516,7 @@ def create_agent( # noqa: PLR0915 model: str | BaseChatModel, tools: Sequence[BaseTool | Callable | dict[str, Any]] | None = None, *, - system_prompt: str | None = None, + system_prompt: str | SystemMessage | None = None, middleware: Sequence[AgentMiddleware[StateT_co, ContextT]] = (), response_format: ResponseFormat[ResponseT] | type[ResponseT] | None = None, state_schema: type[AgentState[ResponseT]] | None = None, @@ -548,11 +548,9 @@ def create_agent( # noqa: PLR0915 If `None` or an empty list, the agent will consist of a model node without a tool calling loop. - system_prompt: An optional system prompt for the LLM. + system_prompt: An optional system prompt for the LLM, given either as a string or + as a [`SystemMessage`][langchain.messages.SystemMessage] object. - Prompts are converted to a - [`SystemMessage`][langchain.messages.SystemMessage] and added to the - beginning of the message list. middleware: A sequence of middleware instances to apply to the agent. Middleware can intercept and modify agent behavior at various stages. 
See @@ -1040,8 +1038,10 @@ def create_agent( # noqa: PLR0915 # Get the bound model (with auto-detection if needed) model_, effective_response_format = _get_bound_model(request) messages = request.messages - if request.system_prompt: - messages = [SystemMessage(request.system_prompt), *messages] + if isinstance(request.system_prompt, SystemMessage): + messages = [request.system_prompt, *messages] + elif request.system_prompt: + messages = [SystemMessage(content=request.system_prompt), *messages] output = model_.invoke(messages) @@ -1093,8 +1093,10 @@ def create_agent( # noqa: PLR0915 # Get the bound model (with auto-detection if needed) model_, effective_response_format = _get_bound_model(request) messages = request.messages - if request.system_prompt: - messages = [SystemMessage(request.system_prompt), *messages] + if isinstance(request.system_prompt, SystemMessage): + messages = [request.system_prompt, *messages] + elif request.system_prompt: + messages = [SystemMessage(content=request.system_prompt), *messages] output = await model_.ainvoke(messages) diff --git a/libs/langchain_v1/langchain/agents/middleware/context_editing.py b/libs/langchain_v1/langchain/agents/middleware/context_editing.py index 3623d1872ce..3de8a380f7e 100644 --- a/libs/langchain_v1/langchain/agents/middleware/context_editing.py +++ b/libs/langchain_v1/langchain/agents/middleware/context_editing.py @@ -225,9 +225,11 @@ class ContextEditingMiddleware(AgentMiddleware): def count_tokens(messages: Sequence[BaseMessage]) -> int: return count_tokens_approximately(messages) else: - system_msg = ( - [SystemMessage(content=request.system_prompt)] if request.system_prompt else [] - ) + system_msg = [] + if request.system_prompt and not isinstance(request.system_prompt, SystemMessage): + system_msg = 
[SystemMessage(content=request.system_prompt)] + elif request.system_prompt and isinstance(request.system_prompt, SystemMessage): + system_msg = [request.system_prompt] def count_tokens(messages: Sequence[BaseMessage]) -> int: return request.model.get_num_tokens_from_messages( @@ -253,9 +255,12 @@ class ContextEditingMiddleware(AgentMiddleware): def count_tokens(messages: Sequence[BaseMessage]) -> int: return count_tokens_approximately(messages) else: - system_msg = ( - [SystemMessage(content=request.system_prompt)] if request.system_prompt else [] - ) + system_msg = [] + + if request.system_prompt and not isinstance(request.system_prompt, SystemMessage): + system_msg = [SystemMessage(content=request.system_prompt)] + elif request.system_prompt and isinstance(request.system_prompt, SystemMessage): + system_msg = [request.system_prompt] def count_tokens(messages: Sequence[BaseMessage]) -> int: return request.model.get_num_tokens_from_messages( diff --git a/libs/langchain_v1/langchain/agents/middleware/todo.py b/libs/langchain_v1/langchain/agents/middleware/todo.py index 5ec0215efc1..c2b1b75d05c 100644 --- a/libs/langchain_v1/langchain/agents/middleware/todo.py +++ b/libs/langchain_v1/langchain/agents/middleware/todo.py @@ -8,7 +8,7 @@ from typing import TYPE_CHECKING, Annotated, Literal if TYPE_CHECKING: from collections.abc import Awaitable, Callable -from langchain_core.messages import ToolMessage +from langchain_core.messages import SystemMessage, ToolMessage from langchain_core.tools import tool from langgraph.types import Command from typing_extensions import NotRequired, TypedDict @@ -199,11 +199,22 @@ class TodoListMiddleware(AgentMiddleware): handler: Callable[[ModelRequest], ModelResponse], ) -> ModelCallResult: """Update the system prompt to include the todo system prompt.""" - request.system_prompt = ( - request.system_prompt + "\n\n" + self.system_prompt - if request.system_prompt - else self.system_prompt - ) + if not request.system_prompt: + 
request.system_prompt = self.system_prompt + elif isinstance(request.system_prompt, str): + request.system_prompt = request.system_prompt + "\n\n" + self.system_prompt + elif isinstance(request.system_prompt, SystemMessage) and isinstance( + request.system_prompt.content, str + ): + request.system_prompt = SystemMessage( + content=request.system_prompt.content + "\n\n" + self.system_prompt + ) + elif isinstance(request.system_prompt, SystemMessage) and isinstance( + request.system_prompt.content, list + ): + request.system_prompt = SystemMessage( + content=[*request.system_prompt.content, self.system_prompt] + ) return handler(request) async def awrap_model_call( @@ -212,9 +223,20 @@ class TodoListMiddleware(AgentMiddleware): handler: Callable[[ModelRequest], Awaitable[ModelResponse]], ) -> ModelCallResult: """Update the system prompt to include the todo system prompt (async version).""" - request.system_prompt = ( - request.system_prompt + "\n\n" + self.system_prompt - if request.system_prompt - else self.system_prompt - ) + if not request.system_prompt: + request.system_prompt = self.system_prompt + elif isinstance(request.system_prompt, str): + request.system_prompt = request.system_prompt + "\n\n" + self.system_prompt + elif isinstance(request.system_prompt, SystemMessage) and isinstance( + request.system_prompt.content, str + ): + request.system_prompt = SystemMessage( + content=request.system_prompt.content + "\n\n" + self.system_prompt + ) + elif isinstance(request.system_prompt, SystemMessage) and isinstance( + request.system_prompt.content, list + ): + request.system_prompt = SystemMessage( + content=[*request.system_prompt.content, self.system_prompt] + ) return await handler(request) diff --git a/libs/langchain_v1/langchain/agents/middleware/types.py b/libs/langchain_v1/langchain/agents/middleware/types.py index cedcc285020..41f16e8f974 100644 --- a/libs/langchain_v1/langchain/agents/middleware/types.py +++ b/libs/langchain_v1/langchain/agents/middleware/types.py @@ 
-26,6 +26,7 @@ from langchain_core.messages import ( # noqa: TC002 AIMessage, AnyMessage, BaseMessage, + SystemMessage, ToolMessage, ) from langgraph.channels.ephemeral_value import EphemeralValue @@ -85,7 +86,7 @@ class ModelRequest: """Model request information for the agent.""" model: BaseChatModel - system_prompt: str | None + system_prompt: str | SystemMessage | None messages: list[AnyMessage] # excluding system prompt tool_choice: Any | None tools: list[BaseTool | dict] @@ -103,7 +104,7 @@ class ModelRequest: Args: **overrides: Keyword arguments for attributes to override. Supported keys: - model: BaseChatModel instance - - system_prompt: Optional system prompt string + - system_prompt: Optional system prompt string or SystemMessage object - messages: List of messages - tool_choice: Tool choice configuration - tools: List of available tools @@ -1256,7 +1257,7 @@ def dynamic_prompt( request: ModelRequest, handler: Callable[[ModelRequest], ModelResponse], ) -> ModelCallResult: - prompt = cast("str", func(request)) + prompt = cast("str | SystemMessage", func(request)) request.system_prompt = prompt return handler(request) @@ -1266,7 +1267,7 @@ def dynamic_prompt( handler: Callable[[ModelRequest], Awaitable[ModelResponse]], ) -> ModelCallResult: # Delegate to sync function - prompt = cast("str", func(request)) + prompt = cast("str | SystemMessage", func(request)) request.system_prompt = prompt return await handler(request) diff --git a/libs/langchain_v1/tests/unit_tests/agents/test_context_editing_middleware.py b/libs/langchain_v1/tests/unit_tests/agents/test_context_editing_middleware.py index 7a9d901b951..33a05730ce0 100644 --- a/libs/langchain_v1/tests/unit_tests/agents/test_context_editing_middleware.py +++ b/libs/langchain_v1/tests/unit_tests/agents/test_context_editing_middleware.py @@ -13,6 +13,7 @@ from langchain_core.language_models.fake_chat_models import FakeChatModel from langchain_core.messages import ( AIMessage, MessageLikeRepresentation, + 
SystemMessage, ToolMessage, ) from langgraph.runtime import Runtime @@ -399,3 +400,126 @@ async def test_exclude_tools_prevents_clearing_async() -> None: assert isinstance(calc_tool, ToolMessage) assert calc_tool.content == "[cleared]" + + +# ============================================================================== +# SystemMessage Tests +# ============================================================================== + + +def test_handles_system_message_prompt() -> None: + """Test that middleware handles SystemMessage as system_prompt correctly.""" + tool_call_id = "call-1" + ai_message = AIMessage( + content="", + tool_calls=[{"id": tool_call_id, "name": "search", "args": {}}], + ) + tool_message = ToolMessage(content="12345", tool_call_id=tool_call_id) + + system_prompt = SystemMessage(content="You are a helpful assistant.") + state, request = _make_state_and_request([ai_message, tool_message], system_prompt=None) + # Manually set SystemMessage as system_prompt + request.system_prompt = system_prompt + + middleware = ContextEditingMiddleware( + edits=[ClearToolUsesEdit(trigger=50)], + token_count_method="model", + ) + + def mock_handler(req: ModelRequest) -> AIMessage: + return AIMessage(content="mock response") + + # Call wrap_model_call - should not fail with SystemMessage + middleware.wrap_model_call(request, mock_handler) + + # Request should have processed without errors + assert request.system_prompt == system_prompt + assert isinstance(request.system_prompt, SystemMessage) + + +def test_does_not_double_wrap_system_message() -> None: + """Test that middleware doesn't wrap SystemMessage in another SystemMessage.""" + tool_call_id = "call-1" + ai_message = AIMessage( + content="", + tool_calls=[{"id": tool_call_id, "name": "search", "args": {}}], + ) + tool_message = ToolMessage(content="x" * 100, tool_call_id=tool_call_id) + + system_prompt = SystemMessage(content="Original system prompt") + state, request = _make_state_and_request([ai_message, 
tool_message], system_prompt=None) + request.system_prompt = system_prompt + + middleware = ContextEditingMiddleware( + edits=[ClearToolUsesEdit(trigger=50)], + token_count_method="model", + ) + + def mock_handler(req: ModelRequest) -> AIMessage: + return AIMessage(content="mock response") + + middleware.wrap_model_call(request, mock_handler) + + # System prompt should still be the same SystemMessage, not wrapped + assert request.system_prompt == system_prompt + assert isinstance(request.system_prompt, SystemMessage) + assert request.system_prompt.content == "Original system prompt" + + +async def test_handles_system_message_prompt_async() -> None: + """Test async version - middleware handles SystemMessage as system_prompt correctly.""" + tool_call_id = "call-1" + ai_message = AIMessage( + content="", + tool_calls=[{"id": tool_call_id, "name": "search", "args": {}}], + ) + tool_message = ToolMessage(content="12345", tool_call_id=tool_call_id) + + system_prompt = SystemMessage(content="You are a helpful assistant.") + state, request = _make_state_and_request([ai_message, tool_message], system_prompt=None) + # Manually set SystemMessage as system_prompt + request.system_prompt = system_prompt + + middleware = ContextEditingMiddleware( + edits=[ClearToolUsesEdit(trigger=50)], + token_count_method="model", + ) + + async def mock_handler(req: ModelRequest) -> AIMessage: + return AIMessage(content="mock response") + + # Call awrap_model_call - should not fail with SystemMessage + await middleware.awrap_model_call(request, mock_handler) + + # Request should have processed without errors + assert request.system_prompt == system_prompt + assert isinstance(request.system_prompt, SystemMessage) + + +async def test_does_not_double_wrap_system_message_async() -> None: + """Test async version - middleware doesn't wrap SystemMessage in another SystemMessage.""" + tool_call_id = "call-1" + ai_message = AIMessage( + content="", + tool_calls=[{"id": tool_call_id, "name": "search", 
"args": {}}], + ) + tool_message = ToolMessage(content="x" * 100, tool_call_id=tool_call_id) + + system_prompt = SystemMessage(content="Original system prompt") + state, request = _make_state_and_request([ai_message, tool_message], system_prompt=None) + request.system_prompt = system_prompt + + middleware = ContextEditingMiddleware( + edits=[ClearToolUsesEdit(trigger=50)], + token_count_method="model", + ) + + async def mock_handler(req: ModelRequest) -> AIMessage: + return AIMessage(content="mock response") + + await middleware.awrap_model_call(request, mock_handler) + + # System prompt should still be the same SystemMessage, not wrapped + assert request.system_prompt == system_prompt + assert isinstance(request.system_prompt, SystemMessage) + assert request.system_prompt.content == "Original system prompt" diff --git a/libs/langchain_v1/tests/unit_tests/agents/test_todo_middleware.py b/libs/langchain_v1/tests/unit_tests/agents/test_todo_middleware.py index 5f96855d4da..bded1390f79 100644 --- a/libs/langchain_v1/tests/unit_tests/agents/test_todo_middleware.py +++ b/libs/langchain_v1/tests/unit_tests/agents/test_todo_middleware.py @@ -5,7 +5,7 @@ from __future__ import annotations from typing import cast from langchain_core.language_models.fake_chat_models import GenericFakeChatModel -from langchain_core.messages import AIMessage +from langchain_core.messages import AIMessage, SystemMessage from langchain.agents.middleware.todo import TodoListMiddleware from langchain.agents.middleware.types import ModelRequest, ModelResponse @@ -170,3 +170,131 @@ async def test_handler_called_with_modified_request_async() -> None: assert received_prompt["value"] is not None assert "Original" in received_prompt["value"] assert "write_todos" in received_prompt["value"] + + +# ============================================================================== +# SystemMessage Tests +# ============================================================================== + + +def 
test_appends_to_system_message_with_list_content() -> None: + """Test that middleware appends to SystemMessage with list content.""" + existing_prompt = SystemMessage(content=["You are helpful.", "Be concise."]) + middleware = TodoListMiddleware() + + model = GenericFakeChatModel(messages=iter([AIMessage(content="response")])) + request = ModelRequest( + model=model, + system_prompt=existing_prompt, + messages=[], + tool_choice=None, + tools=[], + response_format=None, + state=cast("AgentState", {}), # type: ignore[name-defined] + runtime=_fake_runtime(), + model_settings={}, + ) + + def mock_handler(req: ModelRequest) -> ModelResponse: + return ModelResponse(result=[AIMessage(content="response")]) + + middleware.wrap_model_call(request, mock_handler) + + # System prompt should be a SystemMessage with combined content + assert isinstance(request.system_prompt, SystemMessage) + assert isinstance(request.system_prompt.content, list) + assert len(request.system_prompt.content) == 3 + assert request.system_prompt.content[0] == "You are helpful." + assert request.system_prompt.content[1] == "Be concise." 
+ assert "write_todos" in request.system_prompt.content[2] + + +async def test_appends_to_system_message_with_string_content_async() -> None: + """Test async version - middleware appends to SystemMessage with string content.""" + existing_prompt = SystemMessage(content="You are a helpful assistant.") + middleware = TodoListMiddleware() + + model = GenericFakeChatModel(messages=iter([AIMessage(content="response")])) + request = ModelRequest( + model=model, + system_prompt=existing_prompt, + messages=[], + tool_choice=None, + tools=[], + response_format=None, + state=cast("AgentState", {}), # type: ignore[name-defined] + runtime=_fake_runtime(), + model_settings={}, + ) + + async def mock_handler(req: ModelRequest) -> ModelResponse: + return ModelResponse(result=[AIMessage(content="response")]) + + await middleware.awrap_model_call(request, mock_handler) + + # System prompt should be a SystemMessage with combined content + assert isinstance(request.system_prompt, SystemMessage) + assert isinstance(request.system_prompt.content, str) + assert "You are a helpful assistant." 
in request.system_prompt.content + assert middleware.system_prompt in request.system_prompt.content + + +async def test_appends_to_system_message_with_list_content_async() -> None: + """Test async version - middleware appends to SystemMessage with list content.""" + existing_prompt = SystemMessage(content=["You are helpful.", "Be concise."]) + middleware = TodoListMiddleware() + + model = GenericFakeChatModel(messages=iter([AIMessage(content="response")])) + request = ModelRequest( + model=model, + system_prompt=existing_prompt, + messages=[], + tool_choice=None, + tools=[], + response_format=None, + state=cast("AgentState", {}), # type: ignore[name-defined] + runtime=_fake_runtime(), + model_settings={}, + ) + + async def mock_handler(req: ModelRequest) -> ModelResponse: + return ModelResponse(result=[AIMessage(content="response")]) + + await middleware.awrap_model_call(request, mock_handler) + + # System prompt should be a SystemMessage with combined content + assert isinstance(request.system_prompt, SystemMessage) + assert isinstance(request.system_prompt.content, list) + assert len(request.system_prompt.content) == 3 + assert request.system_prompt.content[0] == "You are helpful." + assert request.system_prompt.content[1] == "Be concise." 
+ assert "write_todos" in request.system_prompt.content[2] + + +def test_creates_system_message_when_prompt_is_system_message() -> None: + """Test that middleware preserves SystemMessage type.""" + existing_prompt = SystemMessage(content="Original instructions") + middleware = TodoListMiddleware(system_prompt="Todo instructions") + + model = GenericFakeChatModel(messages=iter([AIMessage(content="response")])) + request = ModelRequest( + model=model, + system_prompt=existing_prompt, + messages=[], + tool_choice=None, + tools=[], + response_format=None, + state=cast("AgentState", {}), # type: ignore[name-defined] + runtime=_fake_runtime(), + model_settings={}, + ) + + def mock_handler(req: ModelRequest) -> ModelResponse: + return ModelResponse(result=[AIMessage(content="response")]) + + result = middleware.wrap_model_call(request, mock_handler) + + # Result should be returned from handler + assert isinstance(result, ModelResponse) + # System prompt should still be a SystemMessage + assert isinstance(request.system_prompt, SystemMessage)