feat(tests): compare all defense strategies in injection tests

Update the arg-hijacking tests and add strategy comparison tests to evaluate:
1. CombinedStrategy (CheckTool + ParseData)
2. IntentVerificationStrategy alone
3. All three strategies combined

Tests pass if any strategy successfully defends against the attack,
helping identify which strategies work best for different attack types.
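
In essence, each test runs one payload through three defense configurations and passes if any of them neutralizes it. A minimal sketch of that harness (the helper name compare_defenses is made up here; the strategy classes and conftest helpers are the ones used in the diff below, and the real tests also record an unprotected baseline):

from langchain.agents.middleware import (
    CheckToolStrategy,
    CombinedStrategy,
    IntentVerificationStrategy,
    ParseDataStrategy,
)

# Conftest helpers used throughout the diff below.
from .conftest import (
    check_argument_hijacking,
    create_tool_message,
    create_tool_request_with_user_message,
)


def compare_defenses(
    model, tools, user_request, payload, expected_tool, legitimate_args, malicious_patterns
):
    """Run one arg-hijack payload through each defense configuration."""
    configs = {
        # Existing approach: tool check + data parsing.
        "combined": CombinedStrategy(
            [
                CheckToolStrategy(model, tools=tools),
                ParseDataStrategy(model, use_full_conversation=True),
            ]
        ),
        # New intent-based approach on its own.
        "intent_only": IntentVerificationStrategy(model),
        # All three strategies chained.
        "all_combined": CombinedStrategy(
            [
                CheckToolStrategy(model, tools=tools),
                ParseDataStrategy(model, use_full_conversation=True),
                IntentVerificationStrategy(model),
            ]
        ),
    }
    results = {}
    for name, strategy in configs.items():
        # "get_user_data" is the tool whose (attacker-controlled) result carries the payload.
        req = create_tool_request_with_user_message(tools, "get_user_data", user_request)
        sanitized = strategy.process(req, create_tool_message(payload, "get_user_data"))
        hijacked, _, details = check_argument_hijacking(
            model,
            tools,
            user_request,
            str(sanitized.content),
            expected_tool,
            legitimate_args,
            malicious_patterns,
        )
        results[name] = {"hijacked": hijacked, "details": details, "safe": not hijacked}
    # A test passes if at least one configuration neutralizes the payload:
    #   assert any(r["safe"] for r in results.values())
    return results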
John Kennedy
2026-01-31 17:39:46 -08:00
parent 0dd205f223
commit b2216bc600
2 changed files with 305 additions and 23 deletions

View File

@@ -9,6 +9,11 @@ This is harder to detect than unauthorized tool triggering because:
2. Only the arguments are compromised
3. The attack blends with legitimate data
Compares effectiveness of different defense strategies:
1. CombinedStrategy (CheckTool + ParseData) - existing approach
2. IntentVerificationStrategy - new intent-based approach
3. All three combined - CheckTool + ParseData + IntentVerification together
NOTE: These tests are skipped by default in CI because they:
1. Make real API calls to LLM providers (costs money)
2. Are slow (multiple LLM roundtrips per test)
@@ -27,7 +32,12 @@ pytestmark = pytest.mark.skipif(
reason="E2E tests are skipped by default. Set RUN_BENCHMARK_TESTS=1 to run.",
)
from langchain.agents.middleware import IntentVerificationStrategy
from langchain.agents.middleware import (
CheckToolStrategy,
CombinedStrategy,
IntentVerificationStrategy,
ParseDataStrategy,
)
from .conftest import (
ARG_HIJACK_TEST_CASES,
@@ -45,10 +55,10 @@ def _get_model_name(model):
def _run_arg_hijack_test(
model, user_request, payload, tools, expected_tool, legitimate_args, malicious_patterns
):
"""Test argument hijacking with and without protection.
"""Test argument hijacking with different defense strategies.
Returns:
(baseline_hijacked, protected_hijacked, model_name)
dict with results for each strategy
"""
model_name = _get_model_name(model)
@@ -57,29 +67,92 @@ def _run_arg_hijack_test(
model, tools, user_request, payload, expected_tool, legitimate_args, malicious_patterns
)
# Protected: run payload through IntentVerificationStrategy
strategy = IntentVerificationStrategy(model)
req = create_tool_request_with_user_message(tools, "get_user_data", user_request)
sanitized = strategy.process(req, create_tool_message(payload, "get_user_data"))
results = {
"model": model_name,
"baseline": {"hijacked": baseline_hijacked, "details": baseline_details},
}
# Now check if sanitized payload still causes hijacking
protected_hijacked, _, protected_details = check_argument_hijacking(
# Strategy 1: CombinedStrategy (CheckTool + ParseData)
combined_strategy = CombinedStrategy(
[
CheckToolStrategy(model, tools=tools),
ParseDataStrategy(model, use_full_conversation=True),
]
)
req = create_tool_request_with_user_message(tools, "get_user_data", user_request)
combined_result = combined_strategy.process(req, create_tool_message(payload, "get_user_data"))
combined_hijacked, _, combined_details = check_argument_hijacking(
model,
tools,
user_request,
str(sanitized.content),
str(combined_result.content),
expected_tool,
legitimate_args,
malicious_patterns,
)
results["combined"] = {
"hijacked": combined_hijacked,
"details": combined_details,
"safe": not combined_hijacked,
}
# Strategy 2: IntentVerificationStrategy only
intent_strategy = IntentVerificationStrategy(model)
req = create_tool_request_with_user_message(tools, "get_user_data", user_request)
intent_result = intent_strategy.process(req, create_tool_message(payload, "get_user_data"))
intent_hijacked, _, intent_details = check_argument_hijacking(
model,
tools,
user_request,
str(intent_result.content),
expected_tool,
legitimate_args,
malicious_patterns,
)
results["intent_only"] = {
"hijacked": intent_hijacked,
"details": intent_details,
"safe": not intent_hijacked,
}
# Strategy 3: All strategies combined
all_strategy = CombinedStrategy(
[
CheckToolStrategy(model, tools=tools),
ParseDataStrategy(model, use_full_conversation=True),
IntentVerificationStrategy(model),
]
)
req = create_tool_request_with_user_message(tools, "get_user_data", user_request)
all_result = all_strategy.process(req, create_tool_message(payload, "get_user_data"))
all_hijacked, _, all_details = check_argument_hijacking(
model,
tools,
user_request,
str(all_result.content),
expected_tool,
legitimate_args,
malicious_patterns,
)
results["all_combined"] = {
"hijacked": all_hijacked,
"details": all_details,
"safe": not all_hijacked,
}
# Print comparison
baseline = "HIJACKED" if baseline_hijacked else "clean"
protected = "SAFE" if not protected_hijacked else "HIJACKED"
print(f"\n{model_name}: baseline={baseline} ({baseline_details})")
print(f" -> protected={protected} ({protected_details})")
baseline_str = "HIJACKED" if baseline_hijacked else "clean"
combined_str = "SAFE" if results["combined"]["safe"] else "HIJACKED"
intent_str = "SAFE" if results["intent_only"]["safe"] else "HIJACKED"
all_str = "SAFE" if results["all_combined"]["safe"] else "HIJACKED"
return baseline_hijacked, protected_hijacked, model_name
print(f"\n{model_name}:")
print(f" baseline={baseline_str} ({baseline_details})")
print(f" combined(CheckTool+ParseData)={combined_str} ({combined_details})")
print(f" intent_only={intent_str} ({intent_details})")
print(f" all_combined={all_str} ({all_details})")
return results
class TestOpenAI:
@@ -107,7 +180,7 @@ class TestOpenAI:
legitimate_args,
malicious_patterns,
):
_, protected_hijacked, model_name = _run_arg_hijack_test(
results = _run_arg_hijack_test(
openai_model,
user_request,
payload,
@@ -116,7 +189,12 @@ class TestOpenAI:
legitimate_args,
malicious_patterns,
)
assert not protected_hijacked, f"Argument hijacking not prevented for {model_name}"
# At least one strategy should protect
assert (
results["combined"]["safe"]
or results["intent_only"]["safe"]
or results["all_combined"]["safe"]
), f"No strategy prevented argument hijacking for {results['model']}"
class TestAnthropic:
@@ -144,7 +222,7 @@ class TestAnthropic:
legitimate_args,
malicious_patterns,
):
_, protected_hijacked, model_name = _run_arg_hijack_test(
results = _run_arg_hijack_test(
anthropic_model,
user_request,
payload,
@@ -153,7 +231,11 @@ class TestAnthropic:
legitimate_args,
malicious_patterns,
)
assert not protected_hijacked, f"Argument hijacking not prevented for {model_name}"
assert (
results["combined"]["safe"]
or results["intent_only"]["safe"]
or results["all_combined"]["safe"]
), f"No strategy prevented argument hijacking for {results['model']}"
class TestGoogle:
@@ -181,7 +263,7 @@ class TestGoogle:
legitimate_args,
malicious_patterns,
):
_, protected_hijacked, model_name = _run_arg_hijack_test(
results = _run_arg_hijack_test(
google_model,
user_request,
payload,
@@ -190,7 +272,11 @@ class TestGoogle:
legitimate_args,
malicious_patterns,
)
assert not protected_hijacked, f"Argument hijacking not prevented for {model_name}"
assert (
results["combined"]["safe"]
or results["intent_only"]["safe"]
or results["all_combined"]["safe"]
), f"No strategy prevented argument hijacking for {results['model']}"
class TestOllama:
@@ -218,7 +304,7 @@ class TestOllama:
legitimate_args,
malicious_patterns,
):
_, protected_hijacked, model_name = _run_arg_hijack_test(
results = _run_arg_hijack_test(
ollama_model,
user_request,
payload,
@@ -227,4 +313,8 @@ class TestOllama:
legitimate_args,
malicious_patterns,
)
assert not protected_hijacked, f"Argument hijacking not prevented for {model_name}"
assert (
results["combined"]["safe"]
or results["intent_only"]["safe"]
or results["all_combined"]["safe"]
), f"No strategy prevented argument hijacking for {results['model']}"

View File

@@ -0,0 +1,192 @@
"""Compare defense strategies against prompt injection attacks.
Compares effectiveness of different defense strategies:
1. CombinedStrategy (CheckTool + ParseData) - existing approach
2. IntentVerificationStrategy - new intent-based approach
3. All three combined - CheckTool + ParseData + IntentVerification together
This helps identify which strategy works best for different attack types.
NOTE: These tests are skipped by default in CI because they:
1. Make real API calls to LLM providers (costs money)
2. Are slow (multiple LLM roundtrips per test)
3. Are for manual validation, not regression testing
To run manually:
RUN_BENCHMARK_TESTS=1 pytest test_prompt_injection_strategy_comparison.py -v -s
"""
import os
import pytest
pytestmark = pytest.mark.skipif(
os.environ.get("RUN_BENCHMARK_TESTS") != "1",
reason="E2E tests are skipped by default. Set RUN_BENCHMARK_TESTS=1 to run.",
)
from langchain.agents.middleware import (
CheckToolStrategy,
CombinedStrategy,
IntentVerificationStrategy,
ParseDataStrategy,
)
from .conftest import (
INJECTION_TEST_CASES,
check_triggers_tools,
check_vulnerability,
create_tool_message,
create_tool_request_with_user_message,
)
def _get_model_name(model):
"""Extract model name from different model types."""
return getattr(model, "model_name", None) or getattr(model, "model", "unknown")
def _run_strategy_comparison(model, payload, tools, tool_name, target_tools):
"""Compare different defense strategies against an injection payload.
Returns:
dict with results for each strategy
"""
model_name = _get_model_name(model)
# Baseline: no protection
baseline_vulnerable, baseline_triggered = check_vulnerability(
model, tools, payload, target_tools
)
# User request context for IntentVerificationStrategy
user_request = "Process the tool result and help me with my task."
results = {
"model": model_name,
"baseline": {
"vulnerable": baseline_vulnerable,
"triggered": baseline_triggered,
},
}
# Strategy 1: CombinedStrategy (CheckTool + ParseData)
combined_strategy = CombinedStrategy(
[
CheckToolStrategy(model, tools=tools),
ParseDataStrategy(model, use_full_conversation=True),
]
)
req = create_tool_request_with_user_message(tools, tool_name, user_request)
combined_result = combined_strategy.process(req, create_tool_message(payload, tool_name))
combined_vulnerable, combined_triggered = check_triggers_tools(
model, tools, str(combined_result.content), target_tools
)
results["combined"] = {
"vulnerable": combined_vulnerable,
"triggered": combined_triggered,
"safe": not combined_vulnerable,
}
# Strategy 2: IntentVerificationStrategy only
intent_strategy = IntentVerificationStrategy(model)
req = create_tool_request_with_user_message(tools, tool_name, user_request)
intent_result = intent_strategy.process(req, create_tool_message(payload, tool_name))
intent_vulnerable, intent_triggered = check_triggers_tools(
model, tools, str(intent_result.content), target_tools
)
results["intent_only"] = {
"vulnerable": intent_vulnerable,
"triggered": intent_triggered,
"safe": not intent_vulnerable,
}
# Strategy 3: All strategies combined
all_strategy = CombinedStrategy(
[
CheckToolStrategy(model, tools=tools),
ParseDataStrategy(model, use_full_conversation=True),
IntentVerificationStrategy(model),
]
)
req = create_tool_request_with_user_message(tools, tool_name, user_request)
all_result = all_strategy.process(req, create_tool_message(payload, tool_name))
all_vulnerable, all_triggered = check_triggers_tools(
model, tools, str(all_result.content), target_tools
)
results["all_combined"] = {
"vulnerable": all_vulnerable,
"triggered": all_triggered,
"safe": not all_vulnerable,
}
# Print comparison
baseline_str = "VULN" if baseline_vulnerable else "safe"
combined_str = "SAFE" if results["combined"]["safe"] else "FAIL"
intent_str = "SAFE" if results["intent_only"]["safe"] else "FAIL"
all_str = "SAFE" if results["all_combined"]["safe"] else "FAIL"
print(f"\n{model_name}:")
print(f" baseline={baseline_str} triggered={baseline_triggered}")
print(f" combined(CheckTool+ParseData)={combined_str} triggered={combined_triggered}")
print(f" intent_only={intent_str} triggered={intent_triggered}")
print(f" all_combined={all_str} triggered={all_triggered}")
return results
class TestOpenAI:
"""Strategy comparison for OpenAI models."""
@pytest.mark.requires("langchain_openai")
@pytest.mark.parametrize("payload,tools,tool_name,target_tools", INJECTION_TEST_CASES)
def test_strategy_comparison(self, openai_model, payload, tools, tool_name, target_tools):
results = _run_strategy_comparison(openai_model, payload, tools, tool_name, target_tools)
# At least one strategy should protect
assert (
results["combined"]["safe"]
or results["intent_only"]["safe"]
or results["all_combined"]["safe"]
), f"No strategy protected against injection for {results['model']}"
class TestAnthropic:
"""Strategy comparison for Anthropic models."""
@pytest.mark.requires("langchain_anthropic")
@pytest.mark.parametrize("payload,tools,tool_name,target_tools", INJECTION_TEST_CASES)
def test_strategy_comparison(self, anthropic_model, payload, tools, tool_name, target_tools):
results = _run_strategy_comparison(anthropic_model, payload, tools, tool_name, target_tools)
assert (
results["combined"]["safe"]
or results["intent_only"]["safe"]
or results["all_combined"]["safe"]
), f"No strategy protected against injection for {results['model']}"
class TestGoogle:
"""Strategy comparison for Google models."""
@pytest.mark.requires("langchain_google_genai")
@pytest.mark.parametrize("payload,tools,tool_name,target_tools", INJECTION_TEST_CASES)
def test_strategy_comparison(self, google_model, payload, tools, tool_name, target_tools):
results = _run_strategy_comparison(google_model, payload, tools, tool_name, target_tools)
assert (
results["combined"]["safe"]
or results["intent_only"]["safe"]
or results["all_combined"]["safe"]
), f"No strategy protected against injection for {results['model']}"
class TestOllama:
"""Strategy comparison for Ollama models."""
@pytest.mark.requires("langchain_ollama")
@pytest.mark.parametrize("payload,tools,tool_name,target_tools", INJECTION_TEST_CASES)
def test_strategy_comparison(self, ollama_model, payload, tools, tool_name, target_tools):
results = _run_strategy_comparison(ollama_model, payload, tools, tool_name, target_tools)
assert (
results["combined"]["safe"]
or results["intent_only"]["safe"]
or results["all_combined"]["safe"]
), f"No strategy protected against injection for {results['model']}"