diff --git a/libs/langchain_v1/langchain/agents/middleware/pii.py b/libs/langchain_v1/langchain/agents/middleware/pii.py index 9b827cb7163..446f74097fa 100644 --- a/libs/langchain_v1/langchain/agents/middleware/pii.py +++ b/libs/langchain_v1/langchain/agents/middleware/pii.py @@ -360,6 +360,7 @@ class PIIMiddleware(AgentMiddleware): __all__ = [ "PIIDetectionError", + "PIIMatch", "PIIMiddleware", "detect_credit_card", "detect_email", diff --git a/libs/langchain_v1/tests/unit_tests/agents/middleware/implementations/test_pii.py b/libs/langchain_v1/tests/unit_tests/agents/middleware/implementations/test_pii.py index a3501fd84a8..40c199591e4 100644 --- a/libs/langchain_v1/tests/unit_tests/agents/middleware/implementations/test_pii.py +++ b/libs/langchain_v1/tests/unit_tests/agents/middleware/implementations/test_pii.py @@ -1,11 +1,16 @@ """Tests for PII detection middleware.""" +from typing import Any + import pytest from langchain_core.messages import AIMessage, HumanMessage, ToolCall, ToolMessage +from langgraph.runtime import Runtime +from langchain.agents import AgentState from langchain.agents.factory import create_agent from langchain.agents.middleware.pii import ( PIIDetectionError, + PIIMatch, PIIMiddleware, detect_credit_card, detect_email, @@ -23,7 +28,7 @@ from tests.unit_tests.agents.model import FakeToolCallingModel class TestEmailDetection: """Test email detection.""" - def test_detect_valid_email(self): + def test_detect_valid_email(self) -> None: content = "Contact me at john.doe@example.com for more info." matches = detect_email(content) @@ -33,7 +38,7 @@ class TestEmailDetection: assert matches[0]["start"] == 14 assert matches[0]["end"] == 34 - def test_detect_multiple_emails(self): + def test_detect_multiple_emails(self) -> None: content = "Email alice@test.com or bob@company.org" matches = detect_email(content) @@ -41,12 +46,12 @@ class TestEmailDetection: assert matches[0]["value"] == "alice@test.com" assert matches[1]["value"] == "bob@company.org" - def test_no_email(self): + def test_no_email(self) -> None: content = "This text has no email addresses." matches = detect_email(content) assert len(matches) == 0 - def test_invalid_email_format(self): + def test_invalid_email_format(self) -> None: content = "Invalid emails: @test.com, user@, user@domain" matches = detect_email(content) # Should not match invalid formats @@ -56,7 +61,7 @@ class TestEmailDetection: class TestCreditCardDetection: """Test credit card detection with Luhn validation.""" - def test_detect_valid_credit_card(self): + def test_detect_valid_credit_card(self) -> None: # Valid Visa test number content = "Card: 4532015112830366" matches = detect_credit_card(content) @@ -65,7 +70,7 @@ class TestCreditCardDetection: assert matches[0]["type"] == "credit_card" assert matches[0]["value"] == "4532015112830366" - def test_detect_credit_card_with_spaces(self): + def test_detect_credit_card_with_spaces(self) -> None: # Valid Mastercard test number # Add spaces spaced_content = "Card: 5425 2334 3010 9903" @@ -74,19 +79,19 @@ class TestCreditCardDetection: assert len(matches) == 1 assert "5425 2334 3010 9903" in matches[0]["value"] - def test_detect_credit_card_with_dashes(self): + def test_detect_credit_card_with_dashes(self) -> None: content = "Card: 4532-0151-1283-0366" matches = detect_credit_card(content) assert len(matches) == 1 - def test_invalid_luhn_not_detected(self): + def test_invalid_luhn_not_detected(self) -> None: # Invalid Luhn checksum content = "Card: 1234567890123456" matches = detect_credit_card(content) assert len(matches) == 0 - def test_no_credit_card(self): + def test_no_credit_card(self) -> None: content = "No cards here." matches = detect_credit_card(content) assert len(matches) == 0 @@ -95,7 +100,7 @@ class TestCreditCardDetection: class TestIPDetection: """Test IP address detection.""" - def test_detect_valid_ipv4(self): + def test_detect_valid_ipv4(self) -> None: content = "Server IP: 192.168.1.1" matches = detect_ip(content) @@ -103,7 +108,7 @@ class TestIPDetection: assert matches[0]["type"] == "ip" assert matches[0]["value"] == "192.168.1.1" - def test_detect_multiple_ips(self): + def test_detect_multiple_ips(self) -> None: content = "Connect to 10.0.0.1 or 8.8.8.8" matches = detect_ip(content) @@ -111,13 +116,13 @@ class TestIPDetection: assert matches[0]["value"] == "10.0.0.1" assert matches[1]["value"] == "8.8.8.8" - def test_invalid_ip_not_detected(self): + def test_invalid_ip_not_detected(self) -> None: # Out of range octets content = "Not an IP: 999.999.999.999" matches = detect_ip(content) assert len(matches) == 0 - def test_version_number_not_detected(self): + def test_version_number_not_detected(self) -> None: # Version numbers should not be detected as IPs content = "Version 1.2.3.4 released" matches = detect_ip(content) @@ -125,7 +130,7 @@ class TestIPDetection: # This is acceptable behavior assert len(matches) >= 0 - def test_no_ip(self): + def test_no_ip(self) -> None: content = "No IP addresses here." matches = detect_ip(content) assert len(matches) == 0 @@ -134,7 +139,7 @@ class TestIPDetection: class TestMACAddressDetection: """Test MAC address detection.""" - def test_detect_mac_with_colons(self): + def test_detect_mac_with_colons(self) -> None: content = "MAC: 00:1A:2B:3C:4D:5E" matches = detect_mac_address(content) @@ -142,26 +147,26 @@ class TestMACAddressDetection: assert matches[0]["type"] == "mac_address" assert matches[0]["value"] == "00:1A:2B:3C:4D:5E" - def test_detect_mac_with_dashes(self): + def test_detect_mac_with_dashes(self) -> None: content = "MAC: 00-1A-2B-3C-4D-5E" matches = detect_mac_address(content) assert len(matches) == 1 assert matches[0]["value"] == "00-1A-2B-3C-4D-5E" - def test_detect_lowercase_mac(self): + def test_detect_lowercase_mac(self) -> None: content = "MAC: aa:bb:cc:dd:ee:ff" matches = detect_mac_address(content) assert len(matches) == 1 assert matches[0]["value"] == "aa:bb:cc:dd:ee:ff" - def test_no_mac(self): + def test_no_mac(self) -> None: content = "No MAC address here." matches = detect_mac_address(content) assert len(matches) == 0 - def test_partial_mac_not_detected(self): + def test_partial_mac_not_detected(self) -> None: content = "Partial: 00:1A:2B:3C" matches = detect_mac_address(content) assert len(matches) == 0 @@ -170,7 +175,7 @@ class TestMACAddressDetection: class TestURLDetection: """Test URL detection.""" - def test_detect_http_url(self): + def test_detect_http_url(self) -> None: content = "Visit http://example.com for details." matches = detect_url(content) @@ -178,39 +183,39 @@ class TestURLDetection: assert matches[0]["type"] == "url" assert matches[0]["value"] == "http://example.com" - def test_detect_https_url(self): + def test_detect_https_url(self) -> None: content = "Visit https://secure.example.com/path" matches = detect_url(content) assert len(matches) == 1 assert matches[0]["value"] == "https://secure.example.com/path" - def test_detect_www_url(self): + def test_detect_www_url(self) -> None: content = "Check www.example.com" matches = detect_url(content) assert len(matches) == 1 assert matches[0]["value"] == "www.example.com" - def test_detect_bare_domain_with_path(self): + def test_detect_bare_domain_with_path(self) -> None: content = "Go to example.com/page" matches = detect_url(content) assert len(matches) == 1 assert matches[0]["value"] == "example.com/page" - def test_detect_multiple_urls(self): + def test_detect_multiple_urls(self) -> None: content = "Visit http://test.com and https://example.org" matches = detect_url(content) assert len(matches) == 2 - def test_no_url(self): + def test_no_url(self) -> None: content = "No URLs here." matches = detect_url(content) assert len(matches) == 0 - def test_bare_domain_without_path_not_detected(self): + def test_bare_domain_without_path_not_detected(self) -> None: # To reduce false positives, bare domains without paths are not detected content = "The word example.com in prose" detect_url(content) @@ -226,21 +231,21 @@ class TestURLDetection: class TestRedactStrategy: """Test redact strategy.""" - def test_redact_email(self): + def test_redact_email(self) -> None: middleware = PIIMiddleware("email", strategy="redact") - state = {"messages": [HumanMessage("Email me at test@example.com")]} + state = AgentState[Any](messages=[HumanMessage("Email me at test@example.com")]) - result = middleware.before_model(state, None) + result = middleware.before_model(state, Runtime()) assert result is not None assert "[REDACTED_EMAIL]" in result["messages"][0].content assert "test@example.com" not in result["messages"][0].content - def test_redact_multiple_pii(self): + def test_redact_multiple_pii(self) -> None: middleware = PIIMiddleware("email", strategy="redact") - state = {"messages": [HumanMessage("Contact alice@test.com or bob@test.com")]} + state = AgentState[Any](messages=[HumanMessage("Contact alice@test.com or bob@test.com")]) - result = middleware.before_model(state, None) + result = middleware.before_model(state, Runtime()) assert result is not None content = result["messages"][0].content @@ -252,34 +257,34 @@ class TestRedactStrategy: class TestMaskStrategy: """Test mask strategy.""" - def test_mask_email(self): + def test_mask_email(self) -> None: middleware = PIIMiddleware("email", strategy="mask") - state = {"messages": [HumanMessage("Email: user@example.com")]} + state = AgentState[Any](messages=[HumanMessage("Email: user@example.com")]) - result = middleware.before_model(state, None) + result = middleware.before_model(state, Runtime()) assert result is not None content = result["messages"][0].content assert "user@****.com" in content assert "user@example.com" not in content - def test_mask_credit_card(self): + def test_mask_credit_card(self) -> None: middleware = PIIMiddleware("credit_card", strategy="mask") # Valid test card - state = {"messages": [HumanMessage("Card: 4532015112830366")]} + state = AgentState[Any](messages=[HumanMessage("Card: 4532015112830366")]) - result = middleware.before_model(state, None) + result = middleware.before_model(state, Runtime()) assert result is not None content = result["messages"][0].content assert "0366" in content # Last 4 digits visible assert "4532015112830366" not in content - def test_mask_ip(self): + def test_mask_ip(self) -> None: middleware = PIIMiddleware("ip", strategy="mask") - state = {"messages": [HumanMessage("IP: 192.168.1.100")]} + state = AgentState[Any](messages=[HumanMessage("IP: 192.168.1.100")]) - result = middleware.before_model(state, None) + result = middleware.before_model(state, Runtime()) assert result is not None content = result["messages"][0].content @@ -290,11 +295,11 @@ class TestMaskStrategy: class TestHashStrategy: """Test hash strategy.""" - def test_hash_email(self): + def test_hash_email(self) -> None: middleware = PIIMiddleware("email", strategy="hash") - state = {"messages": [HumanMessage("Email: test@example.com")]} + state = AgentState[Any](messages=[HumanMessage("Email: test@example.com")]) - result = middleware.before_model(state, None) + result = middleware.before_model(state, Runtime()) assert result is not None content = result["messages"][0].content @@ -302,39 +307,41 @@ class TestHashStrategy: assert ">" in content assert "test@example.com" not in content - def test_hash_is_deterministic(self): + def test_hash_is_deterministic(self) -> None: middleware = PIIMiddleware("email", strategy="hash") # Same email should produce same hash - state1 = {"messages": [HumanMessage("Email: test@example.com")]} - state2 = {"messages": [HumanMessage("Email: test@example.com")]} + state1 = AgentState[Any](messages=[HumanMessage("Email: test@example.com")]) + state2 = AgentState[Any](messages=[HumanMessage("Email: test@example.com")]) - result1 = middleware.before_model(state1, None) - result2 = middleware.before_model(state2, None) + result1 = middleware.before_model(state1, Runtime()) + result2 = middleware.before_model(state2, Runtime()) + assert result1 is not None + assert result2 is not None assert result1["messages"][0].content == result2["messages"][0].content class TestBlockStrategy: """Test block strategy.""" - def test_block_raises_exception(self): + def test_block_raises_exception(self) -> None: middleware = PIIMiddleware("email", strategy="block") - state = {"messages": [HumanMessage("Email: test@example.com")]} + state = AgentState[Any](messages=[HumanMessage("Email: test@example.com")]) with pytest.raises(PIIDetectionError) as exc_info: - middleware.before_model(state, None) + middleware.before_model(state, Runtime()) assert exc_info.value.pii_type == "email" assert len(exc_info.value.matches) == 1 assert "test@example.com" in exc_info.value.matches[0]["value"] - def test_block_with_multiple_matches(self): + def test_block_with_multiple_matches(self) -> None: middleware = PIIMiddleware("email", strategy="block") - state = {"messages": [HumanMessage("Emails: alice@test.com and bob@test.com")]} + state = AgentState[Any](messages=[HumanMessage("Emails: alice@test.com and bob@test.com")]) with pytest.raises(PIIDetectionError) as exc_info: - middleware.before_model(state, None) + middleware.before_model(state, Runtime()) assert len(exc_info.value.matches) == 2 @@ -347,81 +354,81 @@ class TestBlockStrategy: class TestPIIMiddlewareIntegration: """Test PIIMiddleware integration with agent.""" - def test_apply_to_input_only(self): + def test_apply_to_input_only(self) -> None: """Test that middleware only processes input when configured.""" middleware = PIIMiddleware( "email", strategy="redact", apply_to_input=True, apply_to_output=False ) # Should process HumanMessage - state = {"messages": [HumanMessage("Email: test@example.com")]} - result = middleware.before_model(state, None) + state = AgentState[Any](messages=[HumanMessage("Email: test@example.com")]) + result = middleware.before_model(state, Runtime()) assert result is not None assert "[REDACTED_EMAIL]" in result["messages"][0].content # Should not process AIMessage - state = {"messages": [AIMessage("My email is ai@example.com")]} - result = middleware.after_model(state, None) + state = AgentState[Any](messages=[AIMessage("My email is ai@example.com")]) + result = middleware.after_model(state, Runtime()) assert result is None - def test_apply_to_output_only(self): + def test_apply_to_output_only(self) -> None: """Test that middleware only processes output when configured.""" middleware = PIIMiddleware( "email", strategy="redact", apply_to_input=False, apply_to_output=True ) # Should not process HumanMessage - state = {"messages": [HumanMessage("Email: test@example.com")]} - result = middleware.before_model(state, None) + state = AgentState[Any](messages=[HumanMessage("Email: test@example.com")]) + result = middleware.before_model(state, Runtime()) assert result is None # Should process AIMessage - state = {"messages": [AIMessage("My email is ai@example.com")]} - result = middleware.after_model(state, None) + state = AgentState[Any](messages=[AIMessage("My email is ai@example.com")]) + result = middleware.after_model(state, Runtime()) assert result is not None assert "[REDACTED_EMAIL]" in result["messages"][0].content - def test_apply_to_both(self): + def test_apply_to_both(self) -> None: """Test that middleware processes both input and output.""" middleware = PIIMiddleware( "email", strategy="redact", apply_to_input=True, apply_to_output=True ) # Should process HumanMessage - state = {"messages": [HumanMessage("Email: test@example.com")]} - result = middleware.before_model(state, None) + state = AgentState[Any](messages=[HumanMessage("Email: test@example.com")]) + result = middleware.before_model(state, Runtime()) assert result is not None # Should process AIMessage - state = {"messages": [AIMessage("My email is ai@example.com")]} - result = middleware.after_model(state, None) + state = AgentState[Any](messages=[AIMessage("My email is ai@example.com")]) + result = middleware.after_model(state, Runtime()) assert result is not None - def test_no_pii_returns_none(self): + def test_no_pii_returns_none(self) -> None: """Test that middleware returns None when no PII detected.""" middleware = PIIMiddleware("email", strategy="redact") - state = {"messages": [HumanMessage("No PII here")]} + state = AgentState[Any](messages=[HumanMessage("No PII here")]) - result = middleware.before_model(state, None) + result = middleware.before_model(state, Runtime()) assert result is None - def test_empty_messages(self): + def test_empty_messages(self) -> None: """Test that middleware handles empty messages gracefully.""" middleware = PIIMiddleware("email", strategy="redact") - state = {"messages": []} + state = AgentState[Any](messages=[]) - result = middleware.before_model(state, None) + result = middleware.before_model(state, Runtime()) assert result is None - def test_apply_to_tool_results(self): + def test_apply_to_tool_results(self) -> None: """Test that middleware processes tool results when enabled.""" middleware = PIIMiddleware( "email", strategy="redact", apply_to_input=False, apply_to_tool_results=True ) # Simulate a conversation with tool call and result containing PII - state = { - "messages": [ + state = AgentState[Any]( + messages=[ HumanMessage("Search for John"), AIMessage( content="", @@ -429,9 +436,9 @@ class TestPIIMiddlewareIntegration: ), ToolMessage(content="Found: john@example.com", tool_call_id="call_123"), ] - } + ) - result = middleware.before_model(state, None) + result = middleware.before_model(state, Runtime()) assert result is not None # Check that the tool message was redacted @@ -440,14 +447,14 @@ class TestPIIMiddlewareIntegration: assert "[REDACTED_EMAIL]" in tool_msg.content assert "john@example.com" not in tool_msg.content - def test_apply_to_tool_results_mask_strategy(self): + def test_apply_to_tool_results_mask_strategy(self) -> None: """Test that mask strategy works for tool results.""" middleware = PIIMiddleware( "ip", strategy="mask", apply_to_input=False, apply_to_tool_results=True ) - state = { - "messages": [ + state = AgentState[Any]( + messages=[ HumanMessage("Get server IP"), AIMessage( content="", @@ -455,23 +462,23 @@ class TestPIIMiddlewareIntegration: ), ToolMessage(content="Server IP: 192.168.1.100", tool_call_id="call_456"), ] - } + ) - result = middleware.before_model(state, None) + result = middleware.before_model(state, Runtime()) assert result is not None tool_msg = result["messages"][2] assert "*.*.*.100" in tool_msg.content assert "192.168.1.100" not in tool_msg.content - def test_apply_to_tool_results_block_strategy(self): + def test_apply_to_tool_results_block_strategy(self) -> None: """Test that block strategy raises error for PII in tool results.""" middleware = PIIMiddleware( "email", strategy="block", apply_to_input=False, apply_to_tool_results=True ) - state = { - "messages": [ + state = AgentState[Any]( + messages=[ HumanMessage("Search for user"), AIMessage( content="", @@ -479,17 +486,17 @@ class TestPIIMiddlewareIntegration: ), ToolMessage(content="User email: sensitive@example.com", tool_call_id="call_789"), ] - } + ) with pytest.raises(PIIDetectionError) as exc_info: - middleware.before_model(state, None) + middleware.before_model(state, Runtime()) assert exc_info.value.pii_type == "email" assert len(exc_info.value.matches) == 1 - def test_with_agent(self): + def test_with_agent(self) -> None: """Test PIIMiddleware integrated with create_agent.""" - model = FakeToolCallingModel(responses=[AIMessage(content="Thanks for sharing!")]) + model = FakeToolCallingModel() agent = create_agent( model=model, @@ -508,7 +515,7 @@ class TestPIIMiddlewareIntegration: class TestCustomDetector: """Test custom detector functionality.""" - def test_custom_regex_detector(self): + def test_custom_regex_detector(self) -> None: # Custom regex for API keys middleware = PIIMiddleware( "api_key", @@ -516,25 +523,25 @@ class TestCustomDetector: strategy="redact", ) - state = {"messages": [HumanMessage("Key: sk-abcdefghijklmnopqrstuvwxyz123456")]} - result = middleware.before_model(state, None) + state = AgentState[Any](messages=[HumanMessage("Key: sk-abcdefghijklmnopqrstuvwxyz123456")]) + result = middleware.before_model(state, Runtime()) assert result is not None assert "[REDACTED_API_KEY]" in result["messages"][0].content - def test_custom_callable_detector(self): + def test_custom_callable_detector(self) -> None: # Custom detector function - def detect_custom(content): + def detect_custom(content: str) -> list[PIIMatch]: matches = [] if "CONFIDENTIAL" in content: idx = content.index("CONFIDENTIAL") matches.append( - { - "type": "confidential", - "value": "CONFIDENTIAL", - "start": idx, - "end": idx + 12, - } + PIIMatch( + type="confidential", + value="CONFIDENTIAL", + start=idx, + end=idx + 12, + ) ) return matches @@ -544,17 +551,17 @@ class TestCustomDetector: strategy="redact", ) - state = {"messages": [HumanMessage("This is CONFIDENTIAL information")]} - result = middleware.before_model(state, None) + state = AgentState[Any](messages=[HumanMessage("This is CONFIDENTIAL information")]) + result = middleware.before_model(state, Runtime()) assert result is not None assert "[REDACTED_CONFIDENTIAL]" in result["messages"][0].content - def test_unknown_builtin_type_raises_error(self): + def test_unknown_builtin_type_raises_error(self) -> None: with pytest.raises(ValueError, match="Unknown PII type"): PIIMiddleware("unknown_type", strategy="redact") - def test_custom_type_without_detector_raises_error(self): + def test_custom_type_without_detector_raises_error(self) -> None: with pytest.raises(ValueError, match="Unknown PII type"): PIIMiddleware("custom_type", strategy="redact") @@ -562,18 +569,20 @@ class TestCustomDetector: class TestMultipleMiddleware: """Test using multiple PII middleware instances.""" - def test_sequential_application(self): + def test_sequential_application(self) -> None: """Test that multiple PII types are detected when applied sequentially.""" # First apply email middleware email_middleware = PIIMiddleware("email", strategy="redact") - state = {"messages": [HumanMessage("Email: test@example.com, IP: 192.168.1.1")]} - result1 = email_middleware.before_model(state, None) + state = AgentState[Any](messages=[HumanMessage("Email: test@example.com, IP: 192.168.1.1")]) + result1 = email_middleware.before_model(state, Runtime()) # Then apply IP middleware to the result ip_middleware = PIIMiddleware("ip", strategy="mask") - state_with_email_redacted = {"messages": result1["messages"]} - result2 = ip_middleware.before_model(state_with_email_redacted, None) + assert result1 is not None + state_with_email_redacted = AgentState[Any](messages=result1["messages"]) + result2 = ip_middleware.before_model(state_with_email_redacted, Runtime()) + assert result2 is not None content = result2["messages"][0].content # Email should be redacted @@ -584,9 +593,9 @@ class TestMultipleMiddleware: assert "*.*.*.1" in content assert "192.168.1.1" not in content - def test_multiple_pii_middleware_with_create_agent(self): + def test_multiple_pii_middleware_with_create_agent(self) -> None: """Test that multiple PIIMiddleware instances work together in create_agent.""" - model = FakeToolCallingModel(responses=[AIMessage(content="Response received")]) + model = FakeToolCallingModel() # Multiple PIIMiddleware instances should work because each has a unique name agent = create_agent( @@ -611,7 +620,7 @@ class TestMultipleMiddleware: # IP should be masked assert "192.168.1.100" not in content - def test_custom_detector_for_multiple_types(self): + def test_custom_detector_for_multiple_types(self) -> None: """Test using a single middleware with custom detector for multiple PII types. This is an alternative to using multiple middleware instances, @@ -619,7 +628,7 @@ class TestMultipleMiddleware: """ # Combine multiple detectors into one - def detect_email_and_ip(content): + def detect_email_and_ip(content: str) -> list[PIIMatch]: return detect_email(content) + detect_ip(content) middleware = PIIMiddleware( @@ -628,9 +637,10 @@ class TestMultipleMiddleware: strategy="redact", ) - state = {"messages": [HumanMessage("Email: test@example.com, IP: 10.0.0.1")]} - result = middleware.before_model(state, None) + state = AgentState[Any](messages=[HumanMessage("Email: test@example.com, IP: 10.0.0.1")]) + result = middleware.before_model(state, Runtime()) + assert result is not None content = result["messages"][0].content assert "test@example.com" not in content assert "10.0.0.1" not in content