Compare commits

...

1 Commits

Author SHA1 Message Date
Sydney Runkle
2ed8de4f3a first pass system message 2025-11-19 13:35:29 -05:00
6 changed files with 199 additions and 49 deletions

View File

@@ -545,7 +545,7 @@ def create_agent( # noqa: PLR0915
model: str | BaseChatModel,
tools: Sequence[BaseTool | Callable | dict[str, Any]] | None = None,
*,
system_prompt: str | None = None,
system_prompt: str | SystemMessage | None = None,
middleware: Sequence[AgentMiddleware[StateT_co, ContextT]] = (),
response_format: ResponseFormat[ResponseT] | type[ResponseT] | None = None,
state_schema: type[AgentState[ResponseT]] | None = None,
@@ -591,9 +591,9 @@ def create_agent( # noqa: PLR0915
docs for more information.
system_prompt: An optional system prompt for the LLM.
Prompts are converted to a
[`SystemMessage`][langchain.messages.SystemMessage] and added to the
beginning of the message list.
Can be either a string or a [`SystemMessage`][langchain.messages.SystemMessage].
String prompts are converted to a `SystemMessage` and added to the beginning
of the message list.
middleware: A sequence of middleware instances to apply to the agent.
Middleware can intercept and modify agent behavior at various stages.
@@ -688,6 +688,13 @@ def create_agent( # noqa: PLR0915
if isinstance(model, str):
model = init_chat_model(model)
# Convert system_prompt to SystemMessage if it's a string
normalized_system_prompt: SystemMessage | None = None
if isinstance(system_prompt, str):
normalized_system_prompt = SystemMessage(content=system_prompt)
elif isinstance(system_prompt, SystemMessage):
normalized_system_prompt = system_prompt
# Handle tools being None or empty
if tools is None:
tools = []
@@ -1092,7 +1099,7 @@ def create_agent( # noqa: PLR0915
model_, effective_response_format = _get_bound_model(request)
messages = request.messages
if request.system_prompt:
messages = [SystemMessage(request.system_prompt), *messages]
messages = [request.system_prompt, *messages]
output = model_.invoke(messages)
@@ -1111,7 +1118,7 @@ def create_agent( # noqa: PLR0915
request = ModelRequest(
model=model,
tools=default_tools,
system_prompt=system_prompt,
system_prompt=normalized_system_prompt,
response_format=initial_response_format,
messages=state["messages"],
tool_choice=None,
@@ -1145,7 +1152,7 @@ def create_agent( # noqa: PLR0915
model_, effective_response_format = _get_bound_model(request)
messages = request.messages
if request.system_prompt:
messages = [SystemMessage(request.system_prompt), *messages]
messages = [request.system_prompt, *messages]
output = await model_.ainvoke(messages)
@@ -1164,7 +1171,7 @@ def create_agent( # noqa: PLR0915
request = ModelRequest(
model=model,
tools=default_tools,
system_prompt=system_prompt,
system_prompt=normalized_system_prompt,
response_format=initial_response_format,
messages=state["messages"],
tool_choice=None,

View File

@@ -230,9 +230,7 @@ class ContextEditingMiddleware(AgentMiddleware):
def count_tokens(messages: Sequence[BaseMessage]) -> int:
return count_tokens_approximately(messages)
else:
system_msg = (
[SystemMessage(content=request.system_prompt)] if request.system_prompt else []
)
system_msg = [request.system_prompt] if request.system_prompt else []
def count_tokens(messages: Sequence[BaseMessage]) -> int:
return request.model.get_num_tokens_from_messages(
@@ -259,9 +257,7 @@ class ContextEditingMiddleware(AgentMiddleware):
def count_tokens(messages: Sequence[BaseMessage]) -> int:
return count_tokens_approximately(messages)
else:
system_msg = (
[SystemMessage(content=request.system_prompt)] if request.system_prompt else []
)
system_msg = [request.system_prompt] if request.system_prompt else []
def count_tokens(messages: Sequence[BaseMessage]) -> int:
return request.model.get_num_tokens_from_messages(

View File

@@ -8,7 +8,7 @@ from typing import TYPE_CHECKING, Annotated, Literal
if TYPE_CHECKING:
from collections.abc import Awaitable, Callable
from langchain_core.messages import ToolMessage
from langchain_core.messages import SystemMessage, ToolMessage
from langchain_core.tools import tool
from langgraph.types import Command
from typing_extensions import NotRequired, TypedDict
@@ -194,11 +194,11 @@ class TodoListMiddleware(AgentMiddleware):
handler: Callable[[ModelRequest], ModelResponse],
) -> ModelCallResult:
"""Update the system prompt to include the todo system prompt."""
new_system_prompt = (
request.system_prompt + "\n\n" + self.system_prompt
if request.system_prompt
else self.system_prompt
)
if request.system_prompt:
new_content = request.system_prompt.content + "\n\n" + self.system_prompt
new_system_prompt = SystemMessage(content=new_content)
else:
new_system_prompt = SystemMessage(content=self.system_prompt)
return handler(request.override(system_prompt=new_system_prompt))
async def awrap_model_call(
@@ -207,9 +207,9 @@ class TodoListMiddleware(AgentMiddleware):
handler: Callable[[ModelRequest], Awaitable[ModelResponse]],
) -> ModelCallResult:
"""Update the system prompt to include the todo system prompt (async version)."""
new_system_prompt = (
request.system_prompt + "\n\n" + self.system_prompt
if request.system_prompt
else self.system_prompt
)
if request.system_prompt:
new_content = request.system_prompt.content + "\n\n" + self.system_prompt
new_system_prompt = SystemMessage(content=new_content)
else:
new_system_prompt = SystemMessage(content=self.system_prompt)
return await handler(request.override(system_prompt=new_system_prompt))

View File

@@ -72,7 +72,7 @@ class _ModelRequestOverrides(TypedDict, total=False):
"""Possible overrides for `ModelRequest.override()` method."""
model: BaseChatModel
system_prompt: str | None
system_prompt: str | BaseMessage | None
messages: list[AnyMessage]
tool_choice: Any | None
tools: list[BaseTool | dict]
@@ -85,7 +85,7 @@ class ModelRequest:
"""Model request information for the agent."""
model: BaseChatModel
system_prompt: str | None
system_prompt: BaseMessage | None
messages: list[AnyMessage] # excluding system prompt
tool_choice: Any | None
tools: list[BaseTool | dict]
@@ -94,6 +94,24 @@ class ModelRequest:
runtime: Runtime[ContextT] # type: ignore[valid-type]
model_settings: dict[str, Any] = field(default_factory=dict)
def __post_init__(self) -> None:
    """Coerce a legacy ``str`` ``system_prompt`` into a ``SystemMessage``.

    Accepting a plain string is deprecated: a `DeprecationWarning` is emitted
    so callers migrate to passing a `SystemMessage` directly, while the
    coercion keeps existing string-based callers working.
    """
    import warnings

    from langchain_core.messages import SystemMessage

    if not isinstance(self.system_prompt, str):
        return
    warnings.warn(
        "Passing a string for `system_prompt` in `ModelRequest` is deprecated. "
        "Please use a `SystemMessage` instance instead. "
        "String values will be automatically converted to `SystemMessage`, "
        "but this behavior will be removed in a future version.",
        DeprecationWarning,
        stacklevel=3,
    )
    # Bypass this class's custom __setattr__ (which emits its own deprecation
    # warning) so the coercion itself does not warn a second time.
    object.__setattr__(self, "system_prompt", SystemMessage(content=self.system_prompt))
def __setattr__(self, name: str, value: Any) -> None:
"""Set an attribute with a deprecation warning.
@@ -132,7 +150,7 @@ class ModelRequest:
Supported keys:
- `model`: `BaseChatModel` instance
- `system_prompt`: Optional system prompt string
- `system_prompt`: Optional system prompt (`SystemMessage` or `str`)
- `messages`: `list` of messages
- `tool_choice`: Tool choice configuration
- `tools`: `list` of available tools

View File

@@ -95,7 +95,7 @@ def test_adds_system_prompt_when_none_exists() -> None:
# System prompt should be set in the modified request passed to handler
assert captured_request is not None
assert captured_request.system_prompt is not None
assert "write_todos" in captured_request.system_prompt
assert "write_todos" in captured_request.system_prompt.content
# Original request should be unchanged
assert request.system_prompt is None
@@ -118,11 +118,11 @@ def test_appends_to_existing_system_prompt() -> None:
# System prompt should contain both in the modified request passed to handler
assert captured_request is not None
assert captured_request.system_prompt is not None
assert existing_prompt in captured_request.system_prompt
assert "write_todos" in captured_request.system_prompt
assert captured_request.system_prompt.startswith(existing_prompt)
assert existing_prompt in captured_request.system_prompt.content
assert "write_todos" in captured_request.system_prompt.content
assert captured_request.system_prompt.content.startswith(existing_prompt)
# Original request should be unchanged
assert request.system_prompt == existing_prompt
assert request.system_prompt.content == existing_prompt
@pytest.mark.parametrize(
@@ -162,9 +162,12 @@ def test_todo_middleware_on_model_call(original_prompt, expected_prompt_prefix)
middleware.wrap_model_call(request, mock_handler)
# Check that the modified request passed to handler has the expected prompt
assert captured_request is not None
assert captured_request.system_prompt.startswith(expected_prompt_prefix)
# Original request should be unchanged
assert request.system_prompt == original_prompt
assert captured_request.system_prompt.content.startswith(expected_prompt_prefix)
# Original request should be unchanged (it's now a SystemMessage)
if original_prompt is None:
assert request.system_prompt is None
else:
assert request.system_prompt.content == original_prompt
def test_custom_system_prompt() -> None:
@@ -184,7 +187,7 @@ def test_custom_system_prompt() -> None:
# Should use custom prompt in the modified request passed to handler
assert captured_request is not None
assert captured_request.system_prompt == custom_prompt
assert captured_request.system_prompt.content == custom_prompt
# Original request should be unchanged
assert request.system_prompt is None
@@ -220,9 +223,9 @@ def test_todo_middleware_custom_system_prompt() -> None:
middleware.wrap_model_call(request, mock_handler)
# Check that the modified request passed to handler has the expected prompt
assert captured_request is not None
assert captured_request.system_prompt == f"Original prompt\n\n{custom_system_prompt}"
# Original request should be unchanged
assert request.system_prompt == "Original prompt"
assert captured_request.system_prompt.content == f"Original prompt\n\n{custom_system_prompt}"
# Original request should be unchanged (it's now a SystemMessage)
assert request.system_prompt.content == "Original prompt"
def test_custom_tool_description() -> None:
@@ -281,7 +284,7 @@ def test_todo_middleware_custom_system_prompt_and_tool_description() -> None:
middleware.wrap_model_call(request, mock_handler)
# Check that the modified request passed to handler has the expected prompt
assert captured_request is not None
assert captured_request.system_prompt == custom_system_prompt
assert captured_request.system_prompt.content == custom_system_prompt
# Original request should be unchanged
assert request.system_prompt is None
@@ -444,7 +447,7 @@ async def test_adds_system_prompt_when_none_exists_async() -> None:
# System prompt should be set in the modified request passed to handler
assert captured_request is not None
assert captured_request.system_prompt is not None
assert "write_todos" in captured_request.system_prompt
assert "write_todos" in captured_request.system_prompt.content
# Original request should be unchanged
assert request.system_prompt is None
@@ -467,11 +470,11 @@ async def test_appends_to_existing_system_prompt_async() -> None:
# System prompt should contain both in the modified request passed to handler
assert captured_request is not None
assert captured_request.system_prompt is not None
assert existing_prompt in captured_request.system_prompt
assert "write_todos" in captured_request.system_prompt
assert captured_request.system_prompt.startswith(existing_prompt)
assert existing_prompt in captured_request.system_prompt.content
assert "write_todos" in captured_request.system_prompt.content
assert captured_request.system_prompt.content.startswith(existing_prompt)
# Original request should be unchanged
assert request.system_prompt == existing_prompt
assert request.system_prompt.content == existing_prompt
async def test_custom_system_prompt_async() -> None:
@@ -491,7 +494,7 @@ async def test_custom_system_prompt_async() -> None:
# Should use custom prompt in the modified request passed to handler
assert captured_request is not None
assert captured_request.system_prompt == custom_prompt
assert captured_request.system_prompt.content == custom_prompt
# Original request should be unchanged
assert request.system_prompt is None
@@ -512,5 +515,5 @@ async def test_handler_called_with_modified_request_async() -> None:
assert handler_called["value"]
assert received_prompt["value"] is not None
assert "Original" in received_prompt["value"]
assert "write_todos" in received_prompt["value"]
assert "Original" in received_prompt["value"].content
assert "write_todos" in received_prompt["value"].content

View File

@@ -0,0 +1,126 @@
"""Tests for system_prompt support in create_agent and ModelRequest."""
import warnings
import pytest
from langchain_core.messages import HumanMessage, SystemMessage
from langchain.agents import create_agent
from langchain.agents.middleware.types import ModelRequest
from tests.unit_tests.agents.model import FakeToolCallingModel
def test_create_agent_accepts_string_system_prompt():
    """``create_agent`` should accept a plain string as ``system_prompt``."""
    fake_model = FakeToolCallingModel()
    prompt_text = "You are a helpful assistant"
    agent = create_agent(fake_model, system_prompt=prompt_text)
    # A successful invocation shows the string prompt is wired through end to end.
    outcome = agent.invoke({"messages": [HumanMessage(content="Hello")]})
    assert "messages" in outcome
def test_create_agent_accepts_system_message():
    """``create_agent`` should accept a ``SystemMessage`` as ``system_prompt``."""
    chat_model = FakeToolCallingModel()
    prompt_message = SystemMessage(content="You are a helpful assistant")
    agent = create_agent(chat_model, system_prompt=prompt_message)
    # A successful invocation shows the message-typed prompt is supported.
    reply_state = agent.invoke({"messages": [HumanMessage(content="Hello")]})
    assert "messages" in reply_state
def test_model_request_deprecates_string_system_prompt(mock_runtime):
    """ModelRequest should warn on a ``str`` system_prompt and coerce it.

    Uses ``pytest.warns`` instead of a manual ``warnings.catch_warnings`` list
    with ``assert len(w) == 1``: the manual form breaks whenever any unrelated
    warning (e.g. from a third-party import) fires during construction.
    """
    model = FakeToolCallingModel()
    # Passing a plain string is deprecated; it must raise a DeprecationWarning
    # whose message names the offending parameter.
    with pytest.warns(DeprecationWarning, match="system_prompt"):
        request = ModelRequest(
            model=model,
            system_prompt="You are a helpful assistant",
            messages=[],
            tool_choice=None,
            tools=[],
            response_format=None,
            state={"messages": []},
            runtime=mock_runtime,
        )
    # Backward compatibility: the string is coerced to SystemMessage, not rejected.
    assert isinstance(request.system_prompt, SystemMessage)
    assert request.system_prompt.content == "You are a helpful assistant"
def test_model_request_accepts_system_message(mock_runtime):
    """ModelRequest should accept a ``SystemMessage`` with no deprecation warning."""
    model = FakeToolCallingModel()
    system_msg = SystemMessage(content="You are a helpful assistant")
    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter("always")
        request = ModelRequest(
            model=model,
            system_prompt=system_msg,
            messages=[],
            tool_choice=None,
            tools=[],
            response_format=None,
            state={"messages": []},
            runtime=mock_runtime,
        )
    # Assert specifically on DeprecationWarning: ``assert len(caught) == 0``
    # would be flaky against unrelated warnings emitted by other libraries.
    assert not any(issubclass(w.category, DeprecationWarning) for w in caught)
    # The SystemMessage is stored as-is, not re-wrapped.
    assert isinstance(request.system_prompt, SystemMessage)
    assert request.system_prompt.content == "You are a helpful assistant"
def test_model_request_override_with_string(mock_runtime):
    """``ModelRequest.override()`` should warn on a ``str`` prompt and coerce it.

    Uses ``pytest.warns`` instead of ``assert len(w) == 1`` so unrelated
    warnings raised during the call cannot break the test.
    """
    model = FakeToolCallingModel()
    request = ModelRequest(
        model=model,
        system_prompt=SystemMessage(content="Original prompt"),
        messages=[],
        tool_choice=None,
        tools=[],
        response_format=None,
        state={"messages": []},
        runtime=mock_runtime,
    )
    # Overriding with a string goes through the same deprecation path as
    # construction and must warn.
    with pytest.warns(DeprecationWarning, match="system_prompt"):
        new_request = request.override(system_prompt="New prompt")
    # The override produced a coerced SystemMessage.
    assert isinstance(new_request.system_prompt, SystemMessage)
    assert new_request.system_prompt.content == "New prompt"
    # The original request is untouched by override() (consistent with the
    # "original request should be unchanged" checks in the sibling tests).
    assert request.system_prompt.content == "Original prompt"
@pytest.fixture
def mock_runtime():
    """Provide a minimal stand-in for the agent runtime.

    ``ModelRequest`` only needs *some* object in its ``runtime`` slot for
    these tests, so a bare attribute holder suffices.
    """

    class _StubRuntime:
        pass

    return _StubRuntime()