feat(tests): compare all defense strategies in injection tests
Update arg hijacking and add strategy comparison tests to evaluate:

1. CombinedStrategy (CheckTool + ParseData)
2. IntentVerificationStrategy alone
3. All three strategies combined

Tests pass if any strategy successfully defends against the attack, helping identify which strategies work best for different attack types.
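For orientation before the diff, here is a minimal sketch of the three configurations the tests compare, reusing the constructors exactly as they appear in the changed files; the helper name build_strategies is illustrative only and is not part of the test suite:

from langchain.agents.middleware import (
    CheckToolStrategy,
    CombinedStrategy,
    IntentVerificationStrategy,
    ParseDataStrategy,
)


def build_strategies(model, tools):
    """Return the three defense configurations these tests compare (illustrative helper)."""
    return {
        # Existing approach: tool-call screening plus data parsing.
        "combined": CombinedStrategy(
            [
                CheckToolStrategy(model, tools=tools),
                ParseDataStrategy(model, use_full_conversation=True),
            ]
        ),
        # New intent-based approach on its own.
        "intent_only": IntentVerificationStrategy(model),
        # All three strategies chained together.
        "all_combined": CombinedStrategy(
            [
                CheckToolStrategy(model, tools=tools),
                ParseDataStrategy(model, use_full_conversation=True),
                IntentVerificationStrategy(model),
            ]
        ),
    }

A test case passes if at least one of these configurations neutralizes the injected payload.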
@@ -9,6 +9,11 @@ This is harder to detect than unauthorized tool triggering because:
2. Only the arguments are compromised
3. The attack blends with legitimate data

Compares effectiveness of different defense strategies:
1. CombinedStrategy (CheckTool + ParseData) - existing approach
2. IntentVerificationStrategy - new intent-based approach
3. Both combined - all three strategies together

NOTE: These tests are skipped by default in CI because they:
1. Make real API calls to LLM providers (costs money)
2. Are slow (multiple LLM roundtrips per test)
@@ -27,7 +32,12 @@ pytestmark = pytest.mark.skipif(
    reason="E2E tests are skipped by default. Set RUN_BENCHMARK_TESTS=1 to run.",
)

from langchain.agents.middleware import IntentVerificationStrategy
from langchain.agents.middleware import (
    CheckToolStrategy,
    CombinedStrategy,
    IntentVerificationStrategy,
    ParseDataStrategy,
)

from .conftest import (
    ARG_HIJACK_TEST_CASES,
@@ -45,10 +55,10 @@ def _get_model_name(model):
def _run_arg_hijack_test(
    model, user_request, payload, tools, expected_tool, legitimate_args, malicious_patterns
):
    """Test argument hijacking with and without protection.
    """Test argument hijacking with different defense strategies.

    Returns:
        (baseline_hijacked, protected_hijacked, model_name)
        dict with results for each strategy
    """
    model_name = _get_model_name(model)

@@ -57,29 +67,92 @@ def _run_arg_hijack_test(
        model, tools, user_request, payload, expected_tool, legitimate_args, malicious_patterns
    )

    # Protected: run payload through IntentVerificationStrategy
    strategy = IntentVerificationStrategy(model)
    req = create_tool_request_with_user_message(tools, "get_user_data", user_request)
    sanitized = strategy.process(req, create_tool_message(payload, "get_user_data"))
    results = {
        "model": model_name,
        "baseline": {"hijacked": baseline_hijacked, "details": baseline_details},
    }

    # Now check if sanitized payload still causes hijacking
    protected_hijacked, _, protected_details = check_argument_hijacking(
    # Strategy 1: CombinedStrategy (CheckTool + ParseData)
    combined_strategy = CombinedStrategy(
        [
            CheckToolStrategy(model, tools=tools),
            ParseDataStrategy(model, use_full_conversation=True),
        ]
    )
    req = create_tool_request_with_user_message(tools, "get_user_data", user_request)
    combined_result = combined_strategy.process(req, create_tool_message(payload, "get_user_data"))
    combined_hijacked, _, combined_details = check_argument_hijacking(
        model,
        tools,
        user_request,
        str(sanitized.content),
        str(combined_result.content),
        expected_tool,
        legitimate_args,
        malicious_patterns,
    )
    results["combined"] = {
        "hijacked": combined_hijacked,
        "details": combined_details,
        "safe": not combined_hijacked,
    }

    # Strategy 2: IntentVerificationStrategy only
    intent_strategy = IntentVerificationStrategy(model)
    req = create_tool_request_with_user_message(tools, "get_user_data", user_request)
    intent_result = intent_strategy.process(req, create_tool_message(payload, "get_user_data"))
    intent_hijacked, _, intent_details = check_argument_hijacking(
        model,
        tools,
        user_request,
        str(intent_result.content),
        expected_tool,
        legitimate_args,
        malicious_patterns,
    )
    results["intent_only"] = {
        "hijacked": intent_hijacked,
        "details": intent_details,
        "safe": not intent_hijacked,
    }

    # Strategy 3: All strategies combined
    all_strategy = CombinedStrategy(
        [
            CheckToolStrategy(model, tools=tools),
            ParseDataStrategy(model, use_full_conversation=True),
            IntentVerificationStrategy(model),
        ]
    )
    req = create_tool_request_with_user_message(tools, "get_user_data", user_request)
    all_result = all_strategy.process(req, create_tool_message(payload, "get_user_data"))
    all_hijacked, _, all_details = check_argument_hijacking(
        model,
        tools,
        user_request,
        str(all_result.content),
        expected_tool,
        legitimate_args,
        malicious_patterns,
    )
    results["all_combined"] = {
        "hijacked": all_hijacked,
        "details": all_details,
        "safe": not all_hijacked,
    }

    # Print comparison
    baseline = "HIJACKED" if baseline_hijacked else "clean"
    protected = "SAFE" if not protected_hijacked else "HIJACKED"
    print(f"\n{model_name}: baseline={baseline} ({baseline_details})")
    print(f" -> protected={protected} ({protected_details})")
    baseline_str = "HIJACKED" if baseline_hijacked else "clean"
    combined_str = "SAFE" if results["combined"]["safe"] else "HIJACKED"
    intent_str = "SAFE" if results["intent_only"]["safe"] else "HIJACKED"
    all_str = "SAFE" if results["all_combined"]["safe"] else "HIJACKED"

    return baseline_hijacked, protected_hijacked, model_name
    print(f"\n{model_name}:")
    print(f" baseline={baseline_str} ({baseline_details})")
    print(f" combined(CheckTool+ParseData)={combined_str} ({combined_details})")
    print(f" intent_only={intent_str} ({intent_details})")
    print(f" all_combined={all_str} ({all_details})")

    return results


class TestOpenAI:
@@ -107,7 +180,7 @@ class TestOpenAI:
        legitimate_args,
        malicious_patterns,
    ):
        _, protected_hijacked, model_name = _run_arg_hijack_test(
        results = _run_arg_hijack_test(
            openai_model,
            user_request,
            payload,
@@ -116,7 +189,12 @@ class TestOpenAI:
            legitimate_args,
            malicious_patterns,
        )
        assert not protected_hijacked, f"Argument hijacking not prevented for {model_name}"
        # At least one strategy should protect
        assert (
            results["combined"]["safe"]
            or results["intent_only"]["safe"]
            or results["all_combined"]["safe"]
        ), f"No strategy prevented argument hijacking for {results['model']}"


class TestAnthropic:
@@ -144,7 +222,7 @@ class TestAnthropic:
        legitimate_args,
        malicious_patterns,
    ):
        _, protected_hijacked, model_name = _run_arg_hijack_test(
        results = _run_arg_hijack_test(
            anthropic_model,
            user_request,
            payload,
@@ -153,7 +231,11 @@ class TestAnthropic:
            legitimate_args,
            malicious_patterns,
        )
        assert not protected_hijacked, f"Argument hijacking not prevented for {model_name}"
        assert (
            results["combined"]["safe"]
            or results["intent_only"]["safe"]
            or results["all_combined"]["safe"]
        ), f"No strategy prevented argument hijacking for {results['model']}"


class TestGoogle:
@@ -181,7 +263,7 @@ class TestGoogle:
        legitimate_args,
        malicious_patterns,
    ):
        _, protected_hijacked, model_name = _run_arg_hijack_test(
        results = _run_arg_hijack_test(
            google_model,
            user_request,
            payload,
@@ -190,7 +272,11 @@ class TestGoogle:
            legitimate_args,
            malicious_patterns,
        )
        assert not protected_hijacked, f"Argument hijacking not prevented for {model_name}"
        assert (
            results["combined"]["safe"]
            or results["intent_only"]["safe"]
            or results["all_combined"]["safe"]
        ), f"No strategy prevented argument hijacking for {results['model']}"


class TestOllama:
@@ -218,7 +304,7 @@ class TestOllama:
        legitimate_args,
        malicious_patterns,
    ):
        _, protected_hijacked, model_name = _run_arg_hijack_test(
        results = _run_arg_hijack_test(
            ollama_model,
            user_request,
            payload,
@@ -227,4 +313,8 @@ class TestOllama:
            legitimate_args,
            malicious_patterns,
        )
        assert not protected_hijacked, f"Argument hijacking not prevented for {model_name}"
        assert (
            results["combined"]["safe"]
            or results["intent_only"]["safe"]
            or results["all_combined"]["safe"]
        ), f"No strategy prevented argument hijacking for {results['model']}"

@@ -0,0 +1,192 @@
"""Compare defense strategies against prompt injection attacks.

Compares effectiveness of different defense strategies:
1. CombinedStrategy (CheckTool + ParseData) - existing approach
2. IntentVerificationStrategy - new intent-based approach
3. Both combined - all three strategies together

This helps understand which strategy is best for different attack types.

NOTE: These tests are skipped by default in CI because they:
1. Make real API calls to LLM providers (costs money)
2. Are slow (multiple LLM roundtrips per test)
3. Are for manual validation, not regression testing

To run manually:
    RUN_BENCHMARK_TESTS=1 pytest test_prompt_injection_strategy_comparison.py -v -s
"""

import os

import pytest

pytestmark = pytest.mark.skipif(
    os.environ.get("RUN_BENCHMARK_TESTS") != "1",
    reason="E2E tests are skipped by default. Set RUN_BENCHMARK_TESTS=1 to run.",
)

from langchain.agents.middleware import (
    CheckToolStrategy,
    CombinedStrategy,
    IntentVerificationStrategy,
    ParseDataStrategy,
)

from .conftest import (
    INJECTION_TEST_CASES,
    check_triggers_tools,
    check_vulnerability,
    create_tool_message,
    create_tool_request_with_user_message,
)


def _get_model_name(model):
    """Extract model name from different model types."""
    return getattr(model, "model_name", None) or getattr(model, "model", "unknown")


def _run_strategy_comparison(model, payload, tools, tool_name, target_tools):
    """Compare different defense strategies against an injection payload.

    Returns:
        dict with results for each strategy
    """
    model_name = _get_model_name(model)

    # Baseline: no protection
    baseline_vulnerable, baseline_triggered = check_vulnerability(
        model, tools, payload, target_tools
    )

    # User request context for IntentVerificationStrategy
    user_request = "Process the tool result and help me with my task."

    results = {
        "model": model_name,
        "baseline": {
            "vulnerable": baseline_vulnerable,
            "triggered": baseline_triggered,
        },
    }

    # Strategy 1: CombinedStrategy (CheckTool + ParseData)
    combined_strategy = CombinedStrategy(
        [
            CheckToolStrategy(model, tools=tools),
            ParseDataStrategy(model, use_full_conversation=True),
        ]
    )
    req = create_tool_request_with_user_message(tools, tool_name, user_request)
    combined_result = combined_strategy.process(req, create_tool_message(payload, tool_name))
    combined_vulnerable, combined_triggered = check_triggers_tools(
        model, tools, str(combined_result.content), target_tools
    )
    results["combined"] = {
        "vulnerable": combined_vulnerable,
        "triggered": combined_triggered,
        "safe": not combined_vulnerable,
    }

    # Strategy 2: IntentVerificationStrategy only
    intent_strategy = IntentVerificationStrategy(model)
    req = create_tool_request_with_user_message(tools, tool_name, user_request)
    intent_result = intent_strategy.process(req, create_tool_message(payload, tool_name))
    intent_vulnerable, intent_triggered = check_triggers_tools(
        model, tools, str(intent_result.content), target_tools
    )
    results["intent_only"] = {
        "vulnerable": intent_vulnerable,
        "triggered": intent_triggered,
        "safe": not intent_vulnerable,
    }

    # Strategy 3: All strategies combined
    all_strategy = CombinedStrategy(
        [
            CheckToolStrategy(model, tools=tools),
            ParseDataStrategy(model, use_full_conversation=True),
            IntentVerificationStrategy(model),
        ]
    )
    req = create_tool_request_with_user_message(tools, tool_name, user_request)
    all_result = all_strategy.process(req, create_tool_message(payload, tool_name))
    all_vulnerable, all_triggered = check_triggers_tools(
        model, tools, str(all_result.content), target_tools
    )
    results["all_combined"] = {
        "vulnerable": all_vulnerable,
        "triggered": all_triggered,
        "safe": not all_vulnerable,
    }

    # Print comparison
    baseline_str = "VULN" if baseline_vulnerable else "safe"
    combined_str = "SAFE" if results["combined"]["safe"] else "FAIL"
    intent_str = "SAFE" if results["intent_only"]["safe"] else "FAIL"
    all_str = "SAFE" if results["all_combined"]["safe"] else "FAIL"

    print(f"\n{model_name}:")
    print(f" baseline={baseline_str} triggered={baseline_triggered}")
    print(f" combined(CheckTool+ParseData)={combined_str} triggered={combined_triggered}")
    print(f" intent_only={intent_str} triggered={intent_triggered}")
    print(f" all_combined={all_str} triggered={all_triggered}")

    return results


class TestOpenAI:
    """Strategy comparison for OpenAI models."""

    @pytest.mark.requires("langchain_openai")
    @pytest.mark.parametrize("payload,tools,tool_name,target_tools", INJECTION_TEST_CASES)
    def test_strategy_comparison(self, openai_model, payload, tools, tool_name, target_tools):
        results = _run_strategy_comparison(openai_model, payload, tools, tool_name, target_tools)
        # At least one strategy should protect
        assert (
            results["combined"]["safe"]
            or results["intent_only"]["safe"]
            or results["all_combined"]["safe"]
        ), f"No strategy protected against injection for {results['model']}"


class TestAnthropic:
    """Strategy comparison for Anthropic models."""

    @pytest.mark.requires("langchain_anthropic")
    @pytest.mark.parametrize("payload,tools,tool_name,target_tools", INJECTION_TEST_CASES)
    def test_strategy_comparison(self, anthropic_model, payload, tools, tool_name, target_tools):
        results = _run_strategy_comparison(anthropic_model, payload, tools, tool_name, target_tools)
        assert (
            results["combined"]["safe"]
            or results["intent_only"]["safe"]
            or results["all_combined"]["safe"]
        ), f"No strategy protected against injection for {results['model']}"


class TestGoogle:
    """Strategy comparison for Google models."""

    @pytest.mark.requires("langchain_google_genai")
    @pytest.mark.parametrize("payload,tools,tool_name,target_tools", INJECTION_TEST_CASES)
    def test_strategy_comparison(self, google_model, payload, tools, tool_name, target_tools):
        results = _run_strategy_comparison(google_model, payload, tools, tool_name, target_tools)
        assert (
            results["combined"]["safe"]
            or results["intent_only"]["safe"]
            or results["all_combined"]["safe"]
        ), f"No strategy protected against injection for {results['model']}"


class TestOllama:
    """Strategy comparison for Ollama models."""

    @pytest.mark.requires("langchain_ollama")
    @pytest.mark.parametrize("payload,tools,tool_name,target_tools", INJECTION_TEST_CASES)
    def test_strategy_comparison(self, ollama_model, payload, tools, tool_name, target_tools):
        results = _run_strategy_comparison(ollama_model, payload, tools, tool_name, target_tools)
        assert (
            results["combined"]["safe"]
            or results["intent_only"]["safe"]
            or results["all_combined"]["safe"]
        ), f"No strategy protected against injection for {results['model']}"