diff --git a/libs/langchain_v1/tests/unit_tests/agents/middleware/implementations/test_prompt_injection_arg_hijacking.py b/libs/langchain_v1/tests/unit_tests/agents/middleware/implementations/test_prompt_injection_arg_hijacking.py
index 74f600788f2..793b13e977e 100644
--- a/libs/langchain_v1/tests/unit_tests/agents/middleware/implementations/test_prompt_injection_arg_hijacking.py
+++ b/libs/langchain_v1/tests/unit_tests/agents/middleware/implementations/test_prompt_injection_arg_hijacking.py
@@ -9,6 +9,11 @@ This is harder to detect than unauthorized tool triggering because:
 2. Only the arguments are compromised
 3. The attack blends with legitimate data
 
+Compares effectiveness of different defense strategies:
+1. CombinedStrategy (CheckTool + ParseData) - existing approach
+2. IntentVerificationStrategy - new intent-based approach
+3. Both combined - all three strategies together
+
 NOTE: These tests are skipped by default in CI because they:
 1. Make real API calls to LLM providers (costs money)
 2. Are slow (multiple LLM roundtrips per test)
@@ -27,7 +32,12 @@ pytestmark = pytest.mark.skipif(
     reason="E2E tests are skipped by default. Set RUN_BENCHMARK_TESTS=1 to run.",
 )
 
-from langchain.agents.middleware import IntentVerificationStrategy
+from langchain.agents.middleware import (
+    CheckToolStrategy,
+    CombinedStrategy,
+    IntentVerificationStrategy,
+    ParseDataStrategy,
+)
 
 from .conftest import (
     ARG_HIJACK_TEST_CASES,
@@ -45,10 +55,10 @@ def _get_model_name(model):
 def _run_arg_hijack_test(
     model, user_request, payload, tools, expected_tool, legitimate_args, malicious_patterns
 ):
-    """Test argument hijacking with and without protection.
+    """Test argument hijacking with different defense strategies.
 
     Returns:
-        (baseline_hijacked, protected_hijacked, model_name)
+        dict with results for each strategy
     """
     model_name = _get_model_name(model)
 
@@ -57,29 +67,92 @@ def _run_arg_hijack_test(
         model, tools, user_request, payload, expected_tool, legitimate_args, malicious_patterns
     )
 
-    # Protected: run payload through IntentVerificationStrategy
-    strategy = IntentVerificationStrategy(model)
-    req = create_tool_request_with_user_message(tools, "get_user_data", user_request)
-    sanitized = strategy.process(req, create_tool_message(payload, "get_user_data"))
+    results = {
+        "model": model_name,
+        "baseline": {"hijacked": baseline_hijacked, "details": baseline_details},
+    }
 
-    # Now check if sanitized payload still causes hijacking
-    protected_hijacked, _, protected_details = check_argument_hijacking(
+    # Strategy 1: CombinedStrategy (CheckTool + ParseData)
+    combined_strategy = CombinedStrategy(
+        [
+            CheckToolStrategy(model, tools=tools),
+            ParseDataStrategy(model, use_full_conversation=True),
+        ]
+    )
+    req = create_tool_request_with_user_message(tools, "get_user_data", user_request)
+    combined_result = combined_strategy.process(req, create_tool_message(payload, "get_user_data"))
+    combined_hijacked, _, combined_details = check_argument_hijacking(
         model,
         tools,
         user_request,
-        str(sanitized.content),
+        str(combined_result.content),
         expected_tool,
         legitimate_args,
         malicious_patterns,
     )
+    results["combined"] = {
+        "hijacked": combined_hijacked,
+        "details": combined_details,
+        "safe": not combined_hijacked,
+    }
+
+    # Strategy 2: IntentVerificationStrategy only
+    intent_strategy = IntentVerificationStrategy(model)
+    req = create_tool_request_with_user_message(tools, "get_user_data", user_request)
+    intent_result = intent_strategy.process(req, create_tool_message(payload, "get_user_data"))
+    intent_hijacked, _, intent_details = check_argument_hijacking(
+        model,
+        tools,
+        user_request,
+        str(intent_result.content),
+        expected_tool,
+        legitimate_args,
+        malicious_patterns,
+    )
+    results["intent_only"] = {
+        "hijacked": intent_hijacked,
+        "details": intent_details,
+        "safe": not intent_hijacked,
+    }
+
+    # Strategy 3: All strategies combined
+    all_strategy = CombinedStrategy(
+        [
+            CheckToolStrategy(model, tools=tools),
+            ParseDataStrategy(model, use_full_conversation=True),
+            IntentVerificationStrategy(model),
+        ]
+    )
+    req = create_tool_request_with_user_message(tools, "get_user_data", user_request)
+    all_result = all_strategy.process(req, create_tool_message(payload, "get_user_data"))
+    all_hijacked, _, all_details = check_argument_hijacking(
+        model,
+        tools,
+        user_request,
+        str(all_result.content),
+        expected_tool,
+        legitimate_args,
+        malicious_patterns,
+    )
+    results["all_combined"] = {
+        "hijacked": all_hijacked,
+        "details": all_details,
+        "safe": not all_hijacked,
+    }
 
     # Print comparison
-    baseline = "HIJACKED" if baseline_hijacked else "clean"
-    protected = "SAFE" if not protected_hijacked else "HIJACKED"
-    print(f"\n{model_name}: baseline={baseline} ({baseline_details})")
-    print(f"  -> protected={protected} ({protected_details})")
+    baseline_str = "HIJACKED" if baseline_hijacked else "clean"
+    combined_str = "SAFE" if results["combined"]["safe"] else "HIJACKED"
+    intent_str = "SAFE" if results["intent_only"]["safe"] else "HIJACKED"
+    all_str = "SAFE" if results["all_combined"]["safe"] else "HIJACKED"
 
-    return baseline_hijacked, protected_hijacked, model_name
+    print(f"\n{model_name}:")
+    print(f"  baseline={baseline_str} ({baseline_details})")
+    print(f"  combined(CheckTool+ParseData)={combined_str} ({combined_details})")
+    print(f"  intent_only={intent_str} ({intent_details})")
+    print(f"  all_combined={all_str} ({all_details})")
+
+    return results
 
 
 class TestOpenAI:
@@ -107,7 +180,7 @@ class TestOpenAI:
         legitimate_args,
         malicious_patterns,
     ):
-        _, protected_hijacked, model_name = _run_arg_hijack_test(
+        results = _run_arg_hijack_test(
             openai_model,
             user_request,
             payload,
@@ -116,7 +189,12 @@ class TestOpenAI:
             legitimate_args,
             malicious_patterns,
         )
-        assert not protected_hijacked, f"Argument hijacking not prevented for {model_name}"
+        # At least one strategy should protect
+        assert (
+            results["combined"]["safe"]
+            or results["intent_only"]["safe"]
+            or results["all_combined"]["safe"]
+        ), f"No strategy prevented argument hijacking for {results['model']}"
 
 
 class TestAnthropic:
@@ -144,7 +222,7 @@ class TestAnthropic:
         legitimate_args,
         malicious_patterns,
     ):
-        _, protected_hijacked, model_name = _run_arg_hijack_test(
+        results = _run_arg_hijack_test(
            anthropic_model,
            user_request,
            payload,
@@ -153,7 +231,11 @@ class TestAnthropic:
             legitimate_args,
             malicious_patterns,
         )
-        assert not protected_hijacked, f"Argument hijacking not prevented for {model_name}"
+        assert (
+            results["combined"]["safe"]
+            or results["intent_only"]["safe"]
+            or results["all_combined"]["safe"]
+        ), f"No strategy prevented argument hijacking for {results['model']}"
 
 
 class TestGoogle:
@@ -181,7 +263,7 @@ class TestGoogle:
         legitimate_args,
         malicious_patterns,
     ):
-        _, protected_hijacked, model_name = _run_arg_hijack_test(
+        results = _run_arg_hijack_test(
             google_model,
             user_request,
             payload,
@@ -190,7 +272,11 @@ class TestGoogle:
             legitimate_args,
             malicious_patterns,
         )
-        assert not protected_hijacked, f"Argument hijacking not prevented for {model_name}"
+        assert (
+            results["combined"]["safe"]
+            or results["intent_only"]["safe"]
+            or results["all_combined"]["safe"]
+        ), f"No strategy prevented argument hijacking for {results['model']}"
 
 
 class TestOllama:
@@ -218,7 +304,7 @@ class TestOllama:
         legitimate_args,
         malicious_patterns,
    ):
-        _, protected_hijacked, model_name = _run_arg_hijack_test(
+        results = _run_arg_hijack_test(
             ollama_model,
             user_request,
             payload,
@@ -227,4 +313,8 @@ class TestOllama:
             legitimate_args,
             malicious_patterns,
         )
-        assert not protected_hijacked, f"Argument hijacking not prevented for {model_name}"
+        assert (
+            results["combined"]["safe"]
+            or results["intent_only"]["safe"]
+            or results["all_combined"]["safe"]
+        ), f"No strategy prevented argument hijacking for {results['model']}"
diff --git a/libs/langchain_v1/tests/unit_tests/agents/middleware/implementations/test_prompt_injection_strategy_comparison.py b/libs/langchain_v1/tests/unit_tests/agents/middleware/implementations/test_prompt_injection_strategy_comparison.py
new file mode 100644
index 00000000000..2316ab9a875
--- /dev/null
+++ b/libs/langchain_v1/tests/unit_tests/agents/middleware/implementations/test_prompt_injection_strategy_comparison.py
@@ -0,0 +1,192 @@
+"""Compare defense strategies against prompt injection attacks.
+
+Compares effectiveness of different defense strategies:
+1. CombinedStrategy (CheckTool + ParseData) - existing approach
+2. IntentVerificationStrategy - new intent-based approach
+3. Both combined - all three strategies together
+
+This helps understand which strategy is best for different attack types.
+
+NOTE: These tests are skipped by default in CI because they:
+1. Make real API calls to LLM providers (costs money)
+2. Are slow (multiple LLM roundtrips per test)
+3. Are for manual validation, not regression testing
+
+To run manually:
+    RUN_BENCHMARK_TESTS=1 pytest test_prompt_injection_strategy_comparison.py -v -s
+"""
+
+import os
+
+import pytest
+
+pytestmark = pytest.mark.skipif(
+    os.environ.get("RUN_BENCHMARK_TESTS") != "1",
+    reason="E2E tests are skipped by default. Set RUN_BENCHMARK_TESTS=1 to run.",
+)
+
+from langchain.agents.middleware import (
+    CheckToolStrategy,
+    CombinedStrategy,
+    IntentVerificationStrategy,
+    ParseDataStrategy,
+)
+
+from .conftest import (
+    INJECTION_TEST_CASES,
+    check_triggers_tools,
+    check_vulnerability,
+    create_tool_message,
+    create_tool_request_with_user_message,
+)
+
+
+def _get_model_name(model):
+    """Extract model name from different model types."""
+    return getattr(model, "model_name", None) or getattr(model, "model", "unknown")
+
+
+def _run_strategy_comparison(model, payload, tools, tool_name, target_tools):
+    """Compare different defense strategies against an injection payload.
+
+    Returns:
+        dict with results for each strategy
+    """
+    model_name = _get_model_name(model)
+
+    # Baseline: no protection
+    baseline_vulnerable, baseline_triggered = check_vulnerability(
+        model, tools, payload, target_tools
+    )
+
+    # User request context for IntentVerificationStrategy
+    user_request = "Process the tool result and help me with my task."
+
+    results = {
+        "model": model_name,
+        "baseline": {
+            "vulnerable": baseline_vulnerable,
+            "triggered": baseline_triggered,
+        },
+    }
+
+    # Strategy 1: CombinedStrategy (CheckTool + ParseData)
+    combined_strategy = CombinedStrategy(
+        [
+            CheckToolStrategy(model, tools=tools),
+            ParseDataStrategy(model, use_full_conversation=True),
+        ]
+    )
+    req = create_tool_request_with_user_message(tools, tool_name, user_request)
+    combined_result = combined_strategy.process(req, create_tool_message(payload, tool_name))
+    combined_vulnerable, combined_triggered = check_triggers_tools(
+        model, tools, str(combined_result.content), target_tools
+    )
+    results["combined"] = {
+        "vulnerable": combined_vulnerable,
+        "triggered": combined_triggered,
+        "safe": not combined_vulnerable,
+    }
+
+    # Strategy 2: IntentVerificationStrategy only
+    intent_strategy = IntentVerificationStrategy(model)
+    req = create_tool_request_with_user_message(tools, tool_name, user_request)
+    intent_result = intent_strategy.process(req, create_tool_message(payload, tool_name))
+    intent_vulnerable, intent_triggered = check_triggers_tools(
+        model, tools, str(intent_result.content), target_tools
+    )
+    results["intent_only"] = {
+        "vulnerable": intent_vulnerable,
+        "triggered": intent_triggered,
+        "safe": not intent_vulnerable,
+    }
+
+    # Strategy 3: All strategies combined
+    all_strategy = CombinedStrategy(
+        [
+            CheckToolStrategy(model, tools=tools),
+            ParseDataStrategy(model, use_full_conversation=True),
+            IntentVerificationStrategy(model),
+        ]
+    )
+    req = create_tool_request_with_user_message(tools, tool_name, user_request)
+    all_result = all_strategy.process(req, create_tool_message(payload, tool_name))
+    all_vulnerable, all_triggered = check_triggers_tools(
+        model, tools, str(all_result.content), target_tools
+    )
+    results["all_combined"] = {
+        "vulnerable": all_vulnerable,
+        "triggered": all_triggered,
+        "safe": not all_vulnerable,
+    }
+
+    # Print comparison
+    baseline_str = "VULN" if baseline_vulnerable else "safe"
+    combined_str = "SAFE" if results["combined"]["safe"] else "FAIL"
+    intent_str = "SAFE" if results["intent_only"]["safe"] else "FAIL"
+    all_str = "SAFE" if results["all_combined"]["safe"] else "FAIL"
+
+    print(f"\n{model_name}:")
+    print(f"  baseline={baseline_str} triggered={baseline_triggered}")
+    print(f"  combined(CheckTool+ParseData)={combined_str} triggered={combined_triggered}")
+    print(f"  intent_only={intent_str} triggered={intent_triggered}")
+    print(f"  all_combined={all_str} triggered={all_triggered}")
+
+    return results
+
+
+class TestOpenAI:
+    """Strategy comparison for OpenAI models."""
+
+    @pytest.mark.requires("langchain_openai")
+    @pytest.mark.parametrize("payload,tools,tool_name,target_tools", INJECTION_TEST_CASES)
+    def test_strategy_comparison(self, openai_model, payload, tools, tool_name, target_tools):
+        results = _run_strategy_comparison(openai_model, payload, tools, tool_name, target_tools)
+        # At least one strategy should protect
+        assert (
+            results["combined"]["safe"]
+            or results["intent_only"]["safe"]
+            or results["all_combined"]["safe"]
+        ), f"No strategy protected against injection for {results['model']}"
+
+
+class TestAnthropic:
+    """Strategy comparison for Anthropic models."""
+
+    @pytest.mark.requires("langchain_anthropic")
+    @pytest.mark.parametrize("payload,tools,tool_name,target_tools", INJECTION_TEST_CASES)
+    def test_strategy_comparison(self, anthropic_model, payload, tools, tool_name, target_tools):
+        results = _run_strategy_comparison(anthropic_model, payload, tools, tool_name, target_tools)
+        assert (
+            results["combined"]["safe"]
+            or results["intent_only"]["safe"]
+            or results["all_combined"]["safe"]
+        ), f"No strategy protected against injection for {results['model']}"
+
+
+class TestGoogle:
+    """Strategy comparison for Google models."""
+
+    @pytest.mark.requires("langchain_google_genai")
+    @pytest.mark.parametrize("payload,tools,tool_name,target_tools", INJECTION_TEST_CASES)
+    def test_strategy_comparison(self, google_model, payload, tools, tool_name, target_tools):
+        results = _run_strategy_comparison(google_model, payload, tools, tool_name, target_tools)
+        assert (
+            results["combined"]["safe"]
+            or results["intent_only"]["safe"]
+            or results["all_combined"]["safe"]
+        ), f"No strategy protected against injection for {results['model']}"
+
+
+class TestOllama:
+    """Strategy comparison for Ollama models."""
+
+    @pytest.mark.requires("langchain_ollama")
+    @pytest.mark.parametrize("payload,tools,tool_name,target_tools", INJECTION_TEST_CASES)
+    def test_strategy_comparison(self, ollama_model, payload, tools, tool_name, target_tools):
+        results = _run_strategy_comparison(ollama_model, payload, tools, tool_name, target_tools)
+        assert (
+            results["combined"]["safe"]
+            or results["intent_only"]["safe"]
+            or results["all_combined"]["safe"]
+        ), f"No strategy protected against injection for {results['model']}"
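
Reviewer note: a minimal sketch of how the three compared configurations are composed, reusing only the classes and call shapes imported by the tests above (CheckToolStrategy, ParseDataStrategy, IntentVerificationStrategy, CombinedStrategy). The build_defenses helper below is hypothetical and exists only to illustrate the composition; it is not part of the diff or the library API.

# Sketch under the assumptions stated above; signatures mirror the test code, not documented API.
from langchain.agents.middleware import (
    CheckToolStrategy,
    CombinedStrategy,
    IntentVerificationStrategy,
    ParseDataStrategy,
)


def build_defenses(model, tools):
    """Return the three defense configurations the new tests compare."""
    return {
        # Existing approach: strip unauthorized tool calls, then re-parse the tool data.
        "combined": CombinedStrategy(
            [
                CheckToolStrategy(model, tools=tools),
                ParseDataStrategy(model, use_full_conversation=True),
            ]
        ),
        # New approach: check the tool result against the user's stated intent.
        "intent_only": IntentVerificationStrategy(model),
        # All three strategies chained in one CombinedStrategy.
        "all_combined": CombinedStrategy(
            [
                CheckToolStrategy(model, tools=tools),
                ParseDataStrategy(model, use_full_conversation=True),
                IntentVerificationStrategy(model),
            ]
        ),
    }

In the tests, each configuration's process(request, tool_message) output is re-checked with check_argument_hijacking or check_triggers_tools, and a test passes as long as at least one configuration neutralizes the payload.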