mirror of
https://github.com/hwchase17/langchain.git
synced 2026-02-12 12:11:34 +00:00
Compare commits
1 Commits
langchain-
...
sr/system-
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
2ed8de4f3a |
@@ -545,7 +545,7 @@ def create_agent( # noqa: PLR0915
|
||||
model: str | BaseChatModel,
|
||||
tools: Sequence[BaseTool | Callable | dict[str, Any]] | None = None,
|
||||
*,
|
||||
system_prompt: str | None = None,
|
||||
system_prompt: str | SystemMessage | None = None,
|
||||
middleware: Sequence[AgentMiddleware[StateT_co, ContextT]] = (),
|
||||
response_format: ResponseFormat[ResponseT] | type[ResponseT] | None = None,
|
||||
state_schema: type[AgentState[ResponseT]] | None = None,
|
||||
@@ -591,9 +591,9 @@ def create_agent( # noqa: PLR0915
|
||||
docs for more information.
|
||||
system_prompt: An optional system prompt for the LLM.
|
||||
|
||||
Prompts are converted to a
|
||||
[`SystemMessage`][langchain.messages.SystemMessage] and added to the
|
||||
beginning of the message list.
|
||||
Can be either a string or a [`SystemMessage`][langchain.messages.SystemMessage].
|
||||
String prompts are converted to a `SystemMessage` and added to the beginning
|
||||
of the message list.
|
||||
middleware: A sequence of middleware instances to apply to the agent.
|
||||
|
||||
Middleware can intercept and modify agent behavior at various stages.
|
||||
@@ -688,6 +688,13 @@ def create_agent( # noqa: PLR0915
|
||||
if isinstance(model, str):
|
||||
model = init_chat_model(model)
|
||||
|
||||
# Convert system_prompt to SystemMessage if it's a string
|
||||
normalized_system_prompt: SystemMessage | None = None
|
||||
if isinstance(system_prompt, str):
|
||||
normalized_system_prompt = SystemMessage(content=system_prompt)
|
||||
elif isinstance(system_prompt, SystemMessage):
|
||||
normalized_system_prompt = system_prompt
|
||||
|
||||
# Handle tools being None or empty
|
||||
if tools is None:
|
||||
tools = []
|
||||
@@ -1092,7 +1099,7 @@ def create_agent( # noqa: PLR0915
|
||||
model_, effective_response_format = _get_bound_model(request)
|
||||
messages = request.messages
|
||||
if request.system_prompt:
|
||||
messages = [SystemMessage(request.system_prompt), *messages]
|
||||
messages = [request.system_prompt, *messages]
|
||||
|
||||
output = model_.invoke(messages)
|
||||
|
||||
@@ -1111,7 +1118,7 @@ def create_agent( # noqa: PLR0915
|
||||
request = ModelRequest(
|
||||
model=model,
|
||||
tools=default_tools,
|
||||
system_prompt=system_prompt,
|
||||
system_prompt=normalized_system_prompt,
|
||||
response_format=initial_response_format,
|
||||
messages=state["messages"],
|
||||
tool_choice=None,
|
||||
@@ -1145,7 +1152,7 @@ def create_agent( # noqa: PLR0915
|
||||
model_, effective_response_format = _get_bound_model(request)
|
||||
messages = request.messages
|
||||
if request.system_prompt:
|
||||
messages = [SystemMessage(request.system_prompt), *messages]
|
||||
messages = [request.system_prompt, *messages]
|
||||
|
||||
output = await model_.ainvoke(messages)
|
||||
|
||||
@@ -1164,7 +1171,7 @@ def create_agent( # noqa: PLR0915
|
||||
request = ModelRequest(
|
||||
model=model,
|
||||
tools=default_tools,
|
||||
system_prompt=system_prompt,
|
||||
system_prompt=normalized_system_prompt,
|
||||
response_format=initial_response_format,
|
||||
messages=state["messages"],
|
||||
tool_choice=None,
|
||||
|
||||
@@ -230,9 +230,7 @@ class ContextEditingMiddleware(AgentMiddleware):
|
||||
def count_tokens(messages: Sequence[BaseMessage]) -> int:
|
||||
return count_tokens_approximately(messages)
|
||||
else:
|
||||
system_msg = (
|
||||
[SystemMessage(content=request.system_prompt)] if request.system_prompt else []
|
||||
)
|
||||
system_msg = [request.system_prompt] if request.system_prompt else []
|
||||
|
||||
def count_tokens(messages: Sequence[BaseMessage]) -> int:
|
||||
return request.model.get_num_tokens_from_messages(
|
||||
@@ -259,9 +257,7 @@ class ContextEditingMiddleware(AgentMiddleware):
|
||||
def count_tokens(messages: Sequence[BaseMessage]) -> int:
|
||||
return count_tokens_approximately(messages)
|
||||
else:
|
||||
system_msg = (
|
||||
[SystemMessage(content=request.system_prompt)] if request.system_prompt else []
|
||||
)
|
||||
system_msg = [request.system_prompt] if request.system_prompt else []
|
||||
|
||||
def count_tokens(messages: Sequence[BaseMessage]) -> int:
|
||||
return request.model.get_num_tokens_from_messages(
|
||||
|
||||
@@ -8,7 +8,7 @@ from typing import TYPE_CHECKING, Annotated, Literal
|
||||
if TYPE_CHECKING:
|
||||
from collections.abc import Awaitable, Callable
|
||||
|
||||
from langchain_core.messages import ToolMessage
|
||||
from langchain_core.messages import SystemMessage, ToolMessage
|
||||
from langchain_core.tools import tool
|
||||
from langgraph.types import Command
|
||||
from typing_extensions import NotRequired, TypedDict
|
||||
@@ -194,11 +194,11 @@ class TodoListMiddleware(AgentMiddleware):
|
||||
handler: Callable[[ModelRequest], ModelResponse],
|
||||
) -> ModelCallResult:
|
||||
"""Update the system prompt to include the todo system prompt."""
|
||||
new_system_prompt = (
|
||||
request.system_prompt + "\n\n" + self.system_prompt
|
||||
if request.system_prompt
|
||||
else self.system_prompt
|
||||
)
|
||||
if request.system_prompt:
|
||||
new_content = request.system_prompt.content + "\n\n" + self.system_prompt
|
||||
new_system_prompt = SystemMessage(content=new_content)
|
||||
else:
|
||||
new_system_prompt = SystemMessage(content=self.system_prompt)
|
||||
return handler(request.override(system_prompt=new_system_prompt))
|
||||
|
||||
async def awrap_model_call(
|
||||
@@ -207,9 +207,9 @@ class TodoListMiddleware(AgentMiddleware):
|
||||
handler: Callable[[ModelRequest], Awaitable[ModelResponse]],
|
||||
) -> ModelCallResult:
|
||||
"""Update the system prompt to include the todo system prompt (async version)."""
|
||||
new_system_prompt = (
|
||||
request.system_prompt + "\n\n" + self.system_prompt
|
||||
if request.system_prompt
|
||||
else self.system_prompt
|
||||
)
|
||||
if request.system_prompt:
|
||||
new_content = request.system_prompt.content + "\n\n" + self.system_prompt
|
||||
new_system_prompt = SystemMessage(content=new_content)
|
||||
else:
|
||||
new_system_prompt = SystemMessage(content=self.system_prompt)
|
||||
return await handler(request.override(system_prompt=new_system_prompt))
|
||||
|
||||
@@ -72,7 +72,7 @@ class _ModelRequestOverrides(TypedDict, total=False):
|
||||
"""Possible overrides for `ModelRequest.override()` method."""
|
||||
|
||||
model: BaseChatModel
|
||||
system_prompt: str | None
|
||||
system_prompt: str | BaseMessage | None
|
||||
messages: list[AnyMessage]
|
||||
tool_choice: Any | None
|
||||
tools: list[BaseTool | dict]
|
||||
@@ -85,7 +85,7 @@ class ModelRequest:
|
||||
"""Model request information for the agent."""
|
||||
|
||||
model: BaseChatModel
|
||||
system_prompt: str | None
|
||||
system_prompt: BaseMessage | None
|
||||
messages: list[AnyMessage] # excluding system prompt
|
||||
tool_choice: Any | None
|
||||
tools: list[BaseTool | dict]
|
||||
@@ -94,6 +94,24 @@ class ModelRequest:
|
||||
runtime: Runtime[ContextT] # type: ignore[valid-type]
|
||||
model_settings: dict[str, Any] = field(default_factory=dict)
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
"""Validate and coerce system_prompt to SystemMessage with deprecation warning."""
|
||||
import warnings
|
||||
|
||||
from langchain_core.messages import SystemMessage
|
||||
|
||||
if isinstance(self.system_prompt, str):
|
||||
warnings.warn(
|
||||
"Passing a string for `system_prompt` in `ModelRequest` is deprecated. "
|
||||
"Please use a `SystemMessage` instance instead. "
|
||||
"String values will be automatically converted to `SystemMessage`, "
|
||||
"but this behavior will be removed in a future version.",
|
||||
DeprecationWarning,
|
||||
stacklevel=3,
|
||||
)
|
||||
# Coerce string to SystemMessage for backward compatibility
|
||||
object.__setattr__(self, "system_prompt", SystemMessage(content=self.system_prompt))
|
||||
|
||||
def __setattr__(self, name: str, value: Any) -> None:
|
||||
"""Set an attribute with a deprecation warning.
|
||||
|
||||
@@ -132,7 +150,7 @@ class ModelRequest:
|
||||
Supported keys:
|
||||
|
||||
- `model`: `BaseChatModel` instance
|
||||
- `system_prompt`: Optional system prompt string
|
||||
- `system_prompt`: Optional system prompt (`SystemMessage` or `str`)
|
||||
- `messages`: `list` of messages
|
||||
- `tool_choice`: Tool choice configuration
|
||||
- `tools`: `list` of available tools
|
||||
|
||||
@@ -95,7 +95,7 @@ def test_adds_system_prompt_when_none_exists() -> None:
|
||||
# System prompt should be set in the modified request passed to handler
|
||||
assert captured_request is not None
|
||||
assert captured_request.system_prompt is not None
|
||||
assert "write_todos" in captured_request.system_prompt
|
||||
assert "write_todos" in captured_request.system_prompt.content
|
||||
# Original request should be unchanged
|
||||
assert request.system_prompt is None
|
||||
|
||||
@@ -118,11 +118,11 @@ def test_appends_to_existing_system_prompt() -> None:
|
||||
# System prompt should contain both in the modified request passed to handler
|
||||
assert captured_request is not None
|
||||
assert captured_request.system_prompt is not None
|
||||
assert existing_prompt in captured_request.system_prompt
|
||||
assert "write_todos" in captured_request.system_prompt
|
||||
assert captured_request.system_prompt.startswith(existing_prompt)
|
||||
assert existing_prompt in captured_request.system_prompt.content
|
||||
assert "write_todos" in captured_request.system_prompt.content
|
||||
assert captured_request.system_prompt.content.startswith(existing_prompt)
|
||||
# Original request should be unchanged
|
||||
assert request.system_prompt == existing_prompt
|
||||
assert request.system_prompt.content == existing_prompt
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
@@ -162,9 +162,12 @@ def test_todo_middleware_on_model_call(original_prompt, expected_prompt_prefix)
|
||||
middleware.wrap_model_call(request, mock_handler)
|
||||
# Check that the modified request passed to handler has the expected prompt
|
||||
assert captured_request is not None
|
||||
assert captured_request.system_prompt.startswith(expected_prompt_prefix)
|
||||
# Original request should be unchanged
|
||||
assert request.system_prompt == original_prompt
|
||||
assert captured_request.system_prompt.content.startswith(expected_prompt_prefix)
|
||||
# Original request should be unchanged (it's now a SystemMessage)
|
||||
if original_prompt is None:
|
||||
assert request.system_prompt is None
|
||||
else:
|
||||
assert request.system_prompt.content == original_prompt
|
||||
|
||||
|
||||
def test_custom_system_prompt() -> None:
|
||||
@@ -184,7 +187,7 @@ def test_custom_system_prompt() -> None:
|
||||
|
||||
# Should use custom prompt in the modified request passed to handler
|
||||
assert captured_request is not None
|
||||
assert captured_request.system_prompt == custom_prompt
|
||||
assert captured_request.system_prompt.content == custom_prompt
|
||||
# Original request should be unchanged
|
||||
assert request.system_prompt is None
|
||||
|
||||
@@ -220,9 +223,9 @@ def test_todo_middleware_custom_system_prompt() -> None:
|
||||
middleware.wrap_model_call(request, mock_handler)
|
||||
# Check that the modified request passed to handler has the expected prompt
|
||||
assert captured_request is not None
|
||||
assert captured_request.system_prompt == f"Original prompt\n\n{custom_system_prompt}"
|
||||
# Original request should be unchanged
|
||||
assert request.system_prompt == "Original prompt"
|
||||
assert captured_request.system_prompt.content == f"Original prompt\n\n{custom_system_prompt}"
|
||||
# Original request should be unchanged (it's now a SystemMessage)
|
||||
assert request.system_prompt.content == "Original prompt"
|
||||
|
||||
|
||||
def test_custom_tool_description() -> None:
|
||||
@@ -281,7 +284,7 @@ def test_todo_middleware_custom_system_prompt_and_tool_description() -> None:
|
||||
middleware.wrap_model_call(request, mock_handler)
|
||||
# Check that the modified request passed to handler has the expected prompt
|
||||
assert captured_request is not None
|
||||
assert captured_request.system_prompt == custom_system_prompt
|
||||
assert captured_request.system_prompt.content == custom_system_prompt
|
||||
# Original request should be unchanged
|
||||
assert request.system_prompt is None
|
||||
|
||||
@@ -444,7 +447,7 @@ async def test_adds_system_prompt_when_none_exists_async() -> None:
|
||||
# System prompt should be set in the modified request passed to handler
|
||||
assert captured_request is not None
|
||||
assert captured_request.system_prompt is not None
|
||||
assert "write_todos" in captured_request.system_prompt
|
||||
assert "write_todos" in captured_request.system_prompt.content
|
||||
# Original request should be unchanged
|
||||
assert request.system_prompt is None
|
||||
|
||||
@@ -467,11 +470,11 @@ async def test_appends_to_existing_system_prompt_async() -> None:
|
||||
# System prompt should contain both in the modified request passed to handler
|
||||
assert captured_request is not None
|
||||
assert captured_request.system_prompt is not None
|
||||
assert existing_prompt in captured_request.system_prompt
|
||||
assert "write_todos" in captured_request.system_prompt
|
||||
assert captured_request.system_prompt.startswith(existing_prompt)
|
||||
assert existing_prompt in captured_request.system_prompt.content
|
||||
assert "write_todos" in captured_request.system_prompt.content
|
||||
assert captured_request.system_prompt.content.startswith(existing_prompt)
|
||||
# Original request should be unchanged
|
||||
assert request.system_prompt == existing_prompt
|
||||
assert request.system_prompt.content == existing_prompt
|
||||
|
||||
|
||||
async def test_custom_system_prompt_async() -> None:
|
||||
@@ -491,7 +494,7 @@ async def test_custom_system_prompt_async() -> None:
|
||||
|
||||
# Should use custom prompt in the modified request passed to handler
|
||||
assert captured_request is not None
|
||||
assert captured_request.system_prompt == custom_prompt
|
||||
assert captured_request.system_prompt.content == custom_prompt
|
||||
# Original request should be unchanged
|
||||
assert request.system_prompt is None
|
||||
|
||||
@@ -512,5 +515,5 @@ async def test_handler_called_with_modified_request_async() -> None:
|
||||
|
||||
assert handler_called["value"]
|
||||
assert received_prompt["value"] is not None
|
||||
assert "Original" in received_prompt["value"]
|
||||
assert "write_todos" in received_prompt["value"]
|
||||
assert "Original" in received_prompt["value"].content
|
||||
assert "write_todos" in received_prompt["value"].content
|
||||
|
||||
126
libs/langchain_v1/tests/unit_tests/agents/test_system_message.py
Normal file
126
libs/langchain_v1/tests/unit_tests/agents/test_system_message.py
Normal file
@@ -0,0 +1,126 @@
|
||||
"""Tests for system_prompt support in create_agent and ModelRequest."""
|
||||
|
||||
import warnings
|
||||
|
||||
import pytest
|
||||
from langchain_core.messages import HumanMessage, SystemMessage
|
||||
|
||||
from langchain.agents import create_agent
|
||||
from langchain.agents.middleware.types import ModelRequest
|
||||
from tests.unit_tests.agents.model import FakeToolCallingModel
|
||||
|
||||
|
||||
def test_create_agent_accepts_string_system_prompt():
|
||||
"""Test that create_agent accepts a string system_prompt."""
|
||||
model = FakeToolCallingModel()
|
||||
agent = create_agent(model, system_prompt="You are a helpful assistant")
|
||||
|
||||
# Run the agent to ensure it works
|
||||
result = agent.invoke({"messages": [HumanMessage(content="Hello")]})
|
||||
assert "messages" in result
|
||||
|
||||
|
||||
def test_create_agent_accepts_system_message():
|
||||
"""Test that create_agent accepts a SystemMessage for system_prompt."""
|
||||
model = FakeToolCallingModel()
|
||||
system_msg = SystemMessage(content="You are a helpful assistant")
|
||||
agent = create_agent(model, system_prompt=system_msg)
|
||||
|
||||
# Run the agent to ensure it works
|
||||
result = agent.invoke({"messages": [HumanMessage(content="Hello")]})
|
||||
assert "messages" in result
|
||||
|
||||
|
||||
def test_model_request_deprecates_string_system_prompt(mock_runtime):
|
||||
"""Test that ModelRequest raises deprecation warning for string system_prompt."""
|
||||
model = FakeToolCallingModel()
|
||||
|
||||
# Expect a deprecation warning when passing a string
|
||||
with warnings.catch_warnings(record=True) as w:
|
||||
warnings.simplefilter("always")
|
||||
request = ModelRequest(
|
||||
model=model,
|
||||
system_prompt="You are a helpful assistant",
|
||||
messages=[],
|
||||
tool_choice=None,
|
||||
tools=[],
|
||||
response_format=None,
|
||||
state={"messages": []},
|
||||
runtime=mock_runtime,
|
||||
)
|
||||
|
||||
# Check that a deprecation warning was raised
|
||||
assert len(w) == 1
|
||||
assert issubclass(w[0].category, DeprecationWarning)
|
||||
assert "system_prompt" in str(w[0].message)
|
||||
|
||||
# Verify that the string was coerced to SystemMessage
|
||||
assert isinstance(request.system_prompt, SystemMessage)
|
||||
assert request.system_prompt.content == "You are a helpful assistant"
|
||||
|
||||
|
||||
def test_model_request_accepts_system_message(mock_runtime):
|
||||
"""Test that ModelRequest accepts SystemMessage without deprecation warning."""
|
||||
model = FakeToolCallingModel()
|
||||
system_msg = SystemMessage(content="You are a helpful assistant")
|
||||
|
||||
# Should not raise any warning
|
||||
with warnings.catch_warnings(record=True) as w:
|
||||
warnings.simplefilter("always")
|
||||
request = ModelRequest(
|
||||
model=model,
|
||||
system_prompt=system_msg,
|
||||
messages=[],
|
||||
tool_choice=None,
|
||||
tools=[],
|
||||
response_format=None,
|
||||
state={"messages": []},
|
||||
runtime=mock_runtime,
|
||||
)
|
||||
|
||||
# No deprecation warning should be raised
|
||||
assert len(w) == 0
|
||||
|
||||
# Verify that the SystemMessage is preserved
|
||||
assert isinstance(request.system_prompt, SystemMessage)
|
||||
assert request.system_prompt.content == "You are a helpful assistant"
|
||||
|
||||
|
||||
def test_model_request_override_with_string(mock_runtime):
|
||||
"""Test that ModelRequest.override() works with string system_prompt."""
|
||||
model = FakeToolCallingModel()
|
||||
system_msg = SystemMessage(content="Original prompt")
|
||||
|
||||
request = ModelRequest(
|
||||
model=model,
|
||||
system_prompt=system_msg,
|
||||
messages=[],
|
||||
tool_choice=None,
|
||||
tools=[],
|
||||
response_format=None,
|
||||
state={"messages": []},
|
||||
runtime=mock_runtime,
|
||||
)
|
||||
|
||||
# Override with a string - should trigger deprecation warning
|
||||
with warnings.catch_warnings(record=True) as w:
|
||||
warnings.simplefilter("always")
|
||||
new_request = request.override(system_prompt="New prompt")
|
||||
|
||||
# Check for deprecation warning
|
||||
assert len(w) == 1
|
||||
assert issubclass(w[0].category, DeprecationWarning)
|
||||
|
||||
# Verify the override worked
|
||||
assert isinstance(new_request.system_prompt, SystemMessage)
|
||||
assert new_request.system_prompt.content == "New prompt"
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_runtime():
|
||||
"""Create a mock runtime for testing."""
|
||||
|
||||
class MockRuntime:
|
||||
pass
|
||||
|
||||
return MockRuntime()
|
||||
Reference in New Issue
Block a user