chore(langchain): fix types in test_pii (#34617)

This commit is contained in:
Christophe Bornet
2026-01-06 19:06:25 +01:00
committed by GitHub
parent 6537939f53
commit 7f4f130479
2 changed files with 130 additions and 119 deletions

View File

@@ -360,6 +360,7 @@ class PIIMiddleware(AgentMiddleware):
__all__ = [
"PIIDetectionError",
"PIIMatch",
"PIIMiddleware",
"detect_credit_card",
"detect_email",

View File

@@ -1,11 +1,16 @@
"""Tests for PII detection middleware."""
from typing import Any
import pytest
from langchain_core.messages import AIMessage, HumanMessage, ToolCall, ToolMessage
from langgraph.runtime import Runtime
from langchain.agents import AgentState
from langchain.agents.factory import create_agent
from langchain.agents.middleware.pii import (
PIIDetectionError,
PIIMatch,
PIIMiddleware,
detect_credit_card,
detect_email,
@@ -23,7 +28,7 @@ from tests.unit_tests.agents.model import FakeToolCallingModel
class TestEmailDetection:
"""Test email detection."""
def test_detect_valid_email(self):
def test_detect_valid_email(self) -> None:
content = "Contact me at john.doe@example.com for more info."
matches = detect_email(content)
@@ -33,7 +38,7 @@ class TestEmailDetection:
assert matches[0]["start"] == 14
assert matches[0]["end"] == 34
def test_detect_multiple_emails(self):
def test_detect_multiple_emails(self) -> None:
content = "Email alice@test.com or bob@company.org"
matches = detect_email(content)
@@ -41,12 +46,12 @@ class TestEmailDetection:
assert matches[0]["value"] == "alice@test.com"
assert matches[1]["value"] == "bob@company.org"
def test_no_email(self):
def test_no_email(self) -> None:
content = "This text has no email addresses."
matches = detect_email(content)
assert len(matches) == 0
def test_invalid_email_format(self):
def test_invalid_email_format(self) -> None:
content = "Invalid emails: @test.com, user@, user@domain"
matches = detect_email(content)
# Should not match invalid formats
@@ -56,7 +61,7 @@ class TestEmailDetection:
class TestCreditCardDetection:
"""Test credit card detection with Luhn validation."""
def test_detect_valid_credit_card(self):
def test_detect_valid_credit_card(self) -> None:
# Valid Visa test number
content = "Card: 4532015112830366"
matches = detect_credit_card(content)
@@ -65,7 +70,7 @@ class TestCreditCardDetection:
assert matches[0]["type"] == "credit_card"
assert matches[0]["value"] == "4532015112830366"
def test_detect_credit_card_with_spaces(self):
def test_detect_credit_card_with_spaces(self) -> None:
# Valid Mastercard test number
# Add spaces
spaced_content = "Card: 5425 2334 3010 9903"
@@ -74,19 +79,19 @@ class TestCreditCardDetection:
assert len(matches) == 1
assert "5425 2334 3010 9903" in matches[0]["value"]
def test_detect_credit_card_with_dashes(self):
def test_detect_credit_card_with_dashes(self) -> None:
content = "Card: 4532-0151-1283-0366"
matches = detect_credit_card(content)
assert len(matches) == 1
def test_invalid_luhn_not_detected(self):
def test_invalid_luhn_not_detected(self) -> None:
# Invalid Luhn checksum
content = "Card: 1234567890123456"
matches = detect_credit_card(content)
assert len(matches) == 0
def test_no_credit_card(self):
def test_no_credit_card(self) -> None:
content = "No cards here."
matches = detect_credit_card(content)
assert len(matches) == 0
@@ -95,7 +100,7 @@ class TestCreditCardDetection:
class TestIPDetection:
"""Test IP address detection."""
def test_detect_valid_ipv4(self):
def test_detect_valid_ipv4(self) -> None:
content = "Server IP: 192.168.1.1"
matches = detect_ip(content)
@@ -103,7 +108,7 @@ class TestIPDetection:
assert matches[0]["type"] == "ip"
assert matches[0]["value"] == "192.168.1.1"
def test_detect_multiple_ips(self):
def test_detect_multiple_ips(self) -> None:
content = "Connect to 10.0.0.1 or 8.8.8.8"
matches = detect_ip(content)
@@ -111,13 +116,13 @@ class TestIPDetection:
assert matches[0]["value"] == "10.0.0.1"
assert matches[1]["value"] == "8.8.8.8"
def test_invalid_ip_not_detected(self):
def test_invalid_ip_not_detected(self) -> None:
# Out of range octets
content = "Not an IP: 999.999.999.999"
matches = detect_ip(content)
assert len(matches) == 0
def test_version_number_not_detected(self):
def test_version_number_not_detected(self) -> None:
# Version numbers should not be detected as IPs
content = "Version 1.2.3.4 released"
matches = detect_ip(content)
@@ -125,7 +130,7 @@ class TestIPDetection:
# This is acceptable behavior
assert len(matches) >= 0
def test_no_ip(self):
def test_no_ip(self) -> None:
content = "No IP addresses here."
matches = detect_ip(content)
assert len(matches) == 0
@@ -134,7 +139,7 @@ class TestIPDetection:
class TestMACAddressDetection:
"""Test MAC address detection."""
def test_detect_mac_with_colons(self):
def test_detect_mac_with_colons(self) -> None:
content = "MAC: 00:1A:2B:3C:4D:5E"
matches = detect_mac_address(content)
@@ -142,26 +147,26 @@ class TestMACAddressDetection:
assert matches[0]["type"] == "mac_address"
assert matches[0]["value"] == "00:1A:2B:3C:4D:5E"
def test_detect_mac_with_dashes(self):
def test_detect_mac_with_dashes(self) -> None:
content = "MAC: 00-1A-2B-3C-4D-5E"
matches = detect_mac_address(content)
assert len(matches) == 1
assert matches[0]["value"] == "00-1A-2B-3C-4D-5E"
def test_detect_lowercase_mac(self):
def test_detect_lowercase_mac(self) -> None:
content = "MAC: aa:bb:cc:dd:ee:ff"
matches = detect_mac_address(content)
assert len(matches) == 1
assert matches[0]["value"] == "aa:bb:cc:dd:ee:ff"
def test_no_mac(self):
def test_no_mac(self) -> None:
content = "No MAC address here."
matches = detect_mac_address(content)
assert len(matches) == 0
def test_partial_mac_not_detected(self):
def test_partial_mac_not_detected(self) -> None:
content = "Partial: 00:1A:2B:3C"
matches = detect_mac_address(content)
assert len(matches) == 0
@@ -170,7 +175,7 @@ class TestMACAddressDetection:
class TestURLDetection:
"""Test URL detection."""
def test_detect_http_url(self):
def test_detect_http_url(self) -> None:
content = "Visit http://example.com for details."
matches = detect_url(content)
@@ -178,39 +183,39 @@ class TestURLDetection:
assert matches[0]["type"] == "url"
assert matches[0]["value"] == "http://example.com"
def test_detect_https_url(self):
def test_detect_https_url(self) -> None:
content = "Visit https://secure.example.com/path"
matches = detect_url(content)
assert len(matches) == 1
assert matches[0]["value"] == "https://secure.example.com/path"
def test_detect_www_url(self):
def test_detect_www_url(self) -> None:
content = "Check www.example.com"
matches = detect_url(content)
assert len(matches) == 1
assert matches[0]["value"] == "www.example.com"
def test_detect_bare_domain_with_path(self):
def test_detect_bare_domain_with_path(self) -> None:
content = "Go to example.com/page"
matches = detect_url(content)
assert len(matches) == 1
assert matches[0]["value"] == "example.com/page"
def test_detect_multiple_urls(self):
def test_detect_multiple_urls(self) -> None:
content = "Visit http://test.com and https://example.org"
matches = detect_url(content)
assert len(matches) == 2
def test_no_url(self):
def test_no_url(self) -> None:
content = "No URLs here."
matches = detect_url(content)
assert len(matches) == 0
def test_bare_domain_without_path_not_detected(self):
def test_bare_domain_without_path_not_detected(self) -> None:
# To reduce false positives, bare domains without paths are not detected
content = "The word example.com in prose"
detect_url(content)
@@ -226,21 +231,21 @@ class TestURLDetection:
class TestRedactStrategy:
"""Test redact strategy."""
def test_redact_email(self):
def test_redact_email(self) -> None:
middleware = PIIMiddleware("email", strategy="redact")
state = {"messages": [HumanMessage("Email me at test@example.com")]}
state = AgentState[Any](messages=[HumanMessage("Email me at test@example.com")])
result = middleware.before_model(state, None)
result = middleware.before_model(state, Runtime())
assert result is not None
assert "[REDACTED_EMAIL]" in result["messages"][0].content
assert "test@example.com" not in result["messages"][0].content
def test_redact_multiple_pii(self):
def test_redact_multiple_pii(self) -> None:
middleware = PIIMiddleware("email", strategy="redact")
state = {"messages": [HumanMessage("Contact alice@test.com or bob@test.com")]}
state = AgentState[Any](messages=[HumanMessage("Contact alice@test.com or bob@test.com")])
result = middleware.before_model(state, None)
result = middleware.before_model(state, Runtime())
assert result is not None
content = result["messages"][0].content
@@ -252,34 +257,34 @@ class TestRedactStrategy:
class TestMaskStrategy:
"""Test mask strategy."""
def test_mask_email(self):
def test_mask_email(self) -> None:
middleware = PIIMiddleware("email", strategy="mask")
state = {"messages": [HumanMessage("Email: user@example.com")]}
state = AgentState[Any](messages=[HumanMessage("Email: user@example.com")])
result = middleware.before_model(state, None)
result = middleware.before_model(state, Runtime())
assert result is not None
content = result["messages"][0].content
assert "user@****.com" in content
assert "user@example.com" not in content
def test_mask_credit_card(self):
def test_mask_credit_card(self) -> None:
middleware = PIIMiddleware("credit_card", strategy="mask")
# Valid test card
state = {"messages": [HumanMessage("Card: 4532015112830366")]}
state = AgentState[Any](messages=[HumanMessage("Card: 4532015112830366")])
result = middleware.before_model(state, None)
result = middleware.before_model(state, Runtime())
assert result is not None
content = result["messages"][0].content
assert "0366" in content # Last 4 digits visible
assert "4532015112830366" not in content
def test_mask_ip(self):
def test_mask_ip(self) -> None:
middleware = PIIMiddleware("ip", strategy="mask")
state = {"messages": [HumanMessage("IP: 192.168.1.100")]}
state = AgentState[Any](messages=[HumanMessage("IP: 192.168.1.100")])
result = middleware.before_model(state, None)
result = middleware.before_model(state, Runtime())
assert result is not None
content = result["messages"][0].content
@@ -290,11 +295,11 @@ class TestMaskStrategy:
class TestHashStrategy:
"""Test hash strategy."""
def test_hash_email(self):
def test_hash_email(self) -> None:
middleware = PIIMiddleware("email", strategy="hash")
state = {"messages": [HumanMessage("Email: test@example.com")]}
state = AgentState[Any](messages=[HumanMessage("Email: test@example.com")])
result = middleware.before_model(state, None)
result = middleware.before_model(state, Runtime())
assert result is not None
content = result["messages"][0].content
@@ -302,39 +307,41 @@ class TestHashStrategy:
assert ">" in content
assert "test@example.com" not in content
def test_hash_is_deterministic(self):
def test_hash_is_deterministic(self) -> None:
middleware = PIIMiddleware("email", strategy="hash")
# Same email should produce same hash
state1 = {"messages": [HumanMessage("Email: test@example.com")]}
state2 = {"messages": [HumanMessage("Email: test@example.com")]}
state1 = AgentState[Any](messages=[HumanMessage("Email: test@example.com")])
state2 = AgentState[Any](messages=[HumanMessage("Email: test@example.com")])
result1 = middleware.before_model(state1, None)
result2 = middleware.before_model(state2, None)
result1 = middleware.before_model(state1, Runtime())
result2 = middleware.before_model(state2, Runtime())
assert result1 is not None
assert result2 is not None
assert result1["messages"][0].content == result2["messages"][0].content
class TestBlockStrategy:
"""Test block strategy."""
def test_block_raises_exception(self):
def test_block_raises_exception(self) -> None:
middleware = PIIMiddleware("email", strategy="block")
state = {"messages": [HumanMessage("Email: test@example.com")]}
state = AgentState[Any](messages=[HumanMessage("Email: test@example.com")])
with pytest.raises(PIIDetectionError) as exc_info:
middleware.before_model(state, None)
middleware.before_model(state, Runtime())
assert exc_info.value.pii_type == "email"
assert len(exc_info.value.matches) == 1
assert "test@example.com" in exc_info.value.matches[0]["value"]
def test_block_with_multiple_matches(self):
def test_block_with_multiple_matches(self) -> None:
middleware = PIIMiddleware("email", strategy="block")
state = {"messages": [HumanMessage("Emails: alice@test.com and bob@test.com")]}
state = AgentState[Any](messages=[HumanMessage("Emails: alice@test.com and bob@test.com")])
with pytest.raises(PIIDetectionError) as exc_info:
middleware.before_model(state, None)
middleware.before_model(state, Runtime())
assert len(exc_info.value.matches) == 2
@@ -347,81 +354,81 @@ class TestBlockStrategy:
class TestPIIMiddlewareIntegration:
"""Test PIIMiddleware integration with agent."""
def test_apply_to_input_only(self):
def test_apply_to_input_only(self) -> None:
"""Test that middleware only processes input when configured."""
middleware = PIIMiddleware(
"email", strategy="redact", apply_to_input=True, apply_to_output=False
)
# Should process HumanMessage
state = {"messages": [HumanMessage("Email: test@example.com")]}
result = middleware.before_model(state, None)
state = AgentState[Any](messages=[HumanMessage("Email: test@example.com")])
result = middleware.before_model(state, Runtime())
assert result is not None
assert "[REDACTED_EMAIL]" in result["messages"][0].content
# Should not process AIMessage
state = {"messages": [AIMessage("My email is ai@example.com")]}
result = middleware.after_model(state, None)
state = AgentState[Any](messages=[AIMessage("My email is ai@example.com")])
result = middleware.after_model(state, Runtime())
assert result is None
def test_apply_to_output_only(self):
def test_apply_to_output_only(self) -> None:
"""Test that middleware only processes output when configured."""
middleware = PIIMiddleware(
"email", strategy="redact", apply_to_input=False, apply_to_output=True
)
# Should not process HumanMessage
state = {"messages": [HumanMessage("Email: test@example.com")]}
result = middleware.before_model(state, None)
state = AgentState[Any](messages=[HumanMessage("Email: test@example.com")])
result = middleware.before_model(state, Runtime())
assert result is None
# Should process AIMessage
state = {"messages": [AIMessage("My email is ai@example.com")]}
result = middleware.after_model(state, None)
state = AgentState[Any](messages=[AIMessage("My email is ai@example.com")])
result = middleware.after_model(state, Runtime())
assert result is not None
assert "[REDACTED_EMAIL]" in result["messages"][0].content
def test_apply_to_both(self):
def test_apply_to_both(self) -> None:
"""Test that middleware processes both input and output."""
middleware = PIIMiddleware(
"email", strategy="redact", apply_to_input=True, apply_to_output=True
)
# Should process HumanMessage
state = {"messages": [HumanMessage("Email: test@example.com")]}
result = middleware.before_model(state, None)
state = AgentState[Any](messages=[HumanMessage("Email: test@example.com")])
result = middleware.before_model(state, Runtime())
assert result is not None
# Should process AIMessage
state = {"messages": [AIMessage("My email is ai@example.com")]}
result = middleware.after_model(state, None)
state = AgentState[Any](messages=[AIMessage("My email is ai@example.com")])
result = middleware.after_model(state, Runtime())
assert result is not None
def test_no_pii_returns_none(self):
def test_no_pii_returns_none(self) -> None:
"""Test that middleware returns None when no PII detected."""
middleware = PIIMiddleware("email", strategy="redact")
state = {"messages": [HumanMessage("No PII here")]}
state = AgentState[Any](messages=[HumanMessage("No PII here")])
result = middleware.before_model(state, None)
result = middleware.before_model(state, Runtime())
assert result is None
def test_empty_messages(self):
def test_empty_messages(self) -> None:
"""Test that middleware handles empty messages gracefully."""
middleware = PIIMiddleware("email", strategy="redact")
state = {"messages": []}
state = AgentState[Any](messages=[])
result = middleware.before_model(state, None)
result = middleware.before_model(state, Runtime())
assert result is None
def test_apply_to_tool_results(self):
def test_apply_to_tool_results(self) -> None:
"""Test that middleware processes tool results when enabled."""
middleware = PIIMiddleware(
"email", strategy="redact", apply_to_input=False, apply_to_tool_results=True
)
# Simulate a conversation with tool call and result containing PII
state = {
"messages": [
state = AgentState[Any](
messages=[
HumanMessage("Search for John"),
AIMessage(
content="",
@@ -429,9 +436,9 @@ class TestPIIMiddlewareIntegration:
),
ToolMessage(content="Found: john@example.com", tool_call_id="call_123"),
]
}
)
result = middleware.before_model(state, None)
result = middleware.before_model(state, Runtime())
assert result is not None
# Check that the tool message was redacted
@@ -440,14 +447,14 @@ class TestPIIMiddlewareIntegration:
assert "[REDACTED_EMAIL]" in tool_msg.content
assert "john@example.com" not in tool_msg.content
def test_apply_to_tool_results_mask_strategy(self):
def test_apply_to_tool_results_mask_strategy(self) -> None:
"""Test that mask strategy works for tool results."""
middleware = PIIMiddleware(
"ip", strategy="mask", apply_to_input=False, apply_to_tool_results=True
)
state = {
"messages": [
state = AgentState[Any](
messages=[
HumanMessage("Get server IP"),
AIMessage(
content="",
@@ -455,23 +462,23 @@ class TestPIIMiddlewareIntegration:
),
ToolMessage(content="Server IP: 192.168.1.100", tool_call_id="call_456"),
]
}
)
result = middleware.before_model(state, None)
result = middleware.before_model(state, Runtime())
assert result is not None
tool_msg = result["messages"][2]
assert "*.*.*.100" in tool_msg.content
assert "192.168.1.100" not in tool_msg.content
def test_apply_to_tool_results_block_strategy(self):
def test_apply_to_tool_results_block_strategy(self) -> None:
"""Test that block strategy raises error for PII in tool results."""
middleware = PIIMiddleware(
"email", strategy="block", apply_to_input=False, apply_to_tool_results=True
)
state = {
"messages": [
state = AgentState[Any](
messages=[
HumanMessage("Search for user"),
AIMessage(
content="",
@@ -479,17 +486,17 @@ class TestPIIMiddlewareIntegration:
),
ToolMessage(content="User email: sensitive@example.com", tool_call_id="call_789"),
]
}
)
with pytest.raises(PIIDetectionError) as exc_info:
middleware.before_model(state, None)
middleware.before_model(state, Runtime())
assert exc_info.value.pii_type == "email"
assert len(exc_info.value.matches) == 1
def test_with_agent(self):
def test_with_agent(self) -> None:
"""Test PIIMiddleware integrated with create_agent."""
model = FakeToolCallingModel(responses=[AIMessage(content="Thanks for sharing!")])
model = FakeToolCallingModel()
agent = create_agent(
model=model,
@@ -508,7 +515,7 @@ class TestPIIMiddlewareIntegration:
class TestCustomDetector:
"""Test custom detector functionality."""
def test_custom_regex_detector(self):
def test_custom_regex_detector(self) -> None:
# Custom regex for API keys
middleware = PIIMiddleware(
"api_key",
@@ -516,25 +523,25 @@ class TestCustomDetector:
strategy="redact",
)
state = {"messages": [HumanMessage("Key: sk-abcdefghijklmnopqrstuvwxyz123456")]}
result = middleware.before_model(state, None)
state = AgentState[Any](messages=[HumanMessage("Key: sk-abcdefghijklmnopqrstuvwxyz123456")])
result = middleware.before_model(state, Runtime())
assert result is not None
assert "[REDACTED_API_KEY]" in result["messages"][0].content
def test_custom_callable_detector(self):
def test_custom_callable_detector(self) -> None:
# Custom detector function
def detect_custom(content):
def detect_custom(content: str) -> list[PIIMatch]:
matches = []
if "CONFIDENTIAL" in content:
idx = content.index("CONFIDENTIAL")
matches.append(
{
"type": "confidential",
"value": "CONFIDENTIAL",
"start": idx,
"end": idx + 12,
}
PIIMatch(
type="confidential",
value="CONFIDENTIAL",
start=idx,
end=idx + 12,
)
)
return matches
@@ -544,17 +551,17 @@ class TestCustomDetector:
strategy="redact",
)
state = {"messages": [HumanMessage("This is CONFIDENTIAL information")]}
result = middleware.before_model(state, None)
state = AgentState[Any](messages=[HumanMessage("This is CONFIDENTIAL information")])
result = middleware.before_model(state, Runtime())
assert result is not None
assert "[REDACTED_CONFIDENTIAL]" in result["messages"][0].content
def test_unknown_builtin_type_raises_error(self):
def test_unknown_builtin_type_raises_error(self) -> None:
with pytest.raises(ValueError, match="Unknown PII type"):
PIIMiddleware("unknown_type", strategy="redact")
def test_custom_type_without_detector_raises_error(self):
def test_custom_type_without_detector_raises_error(self) -> None:
with pytest.raises(ValueError, match="Unknown PII type"):
PIIMiddleware("custom_type", strategy="redact")
@@ -562,18 +569,20 @@ class TestCustomDetector:
class TestMultipleMiddleware:
"""Test using multiple PII middleware instances."""
def test_sequential_application(self):
def test_sequential_application(self) -> None:
"""Test that multiple PII types are detected when applied sequentially."""
# First apply email middleware
email_middleware = PIIMiddleware("email", strategy="redact")
state = {"messages": [HumanMessage("Email: test@example.com, IP: 192.168.1.1")]}
result1 = email_middleware.before_model(state, None)
state = AgentState[Any](messages=[HumanMessage("Email: test@example.com, IP: 192.168.1.1")])
result1 = email_middleware.before_model(state, Runtime())
# Then apply IP middleware to the result
ip_middleware = PIIMiddleware("ip", strategy="mask")
state_with_email_redacted = {"messages": result1["messages"]}
result2 = ip_middleware.before_model(state_with_email_redacted, None)
assert result1 is not None
state_with_email_redacted = AgentState[Any](messages=result1["messages"])
result2 = ip_middleware.before_model(state_with_email_redacted, Runtime())
assert result2 is not None
content = result2["messages"][0].content
# Email should be redacted
@@ -584,9 +593,9 @@ class TestMultipleMiddleware:
assert "*.*.*.1" in content
assert "192.168.1.1" not in content
def test_multiple_pii_middleware_with_create_agent(self):
def test_multiple_pii_middleware_with_create_agent(self) -> None:
"""Test that multiple PIIMiddleware instances work together in create_agent."""
model = FakeToolCallingModel(responses=[AIMessage(content="Response received")])
model = FakeToolCallingModel()
# Multiple PIIMiddleware instances should work because each has a unique name
agent = create_agent(
@@ -611,7 +620,7 @@ class TestMultipleMiddleware:
# IP should be masked
assert "192.168.1.100" not in content
def test_custom_detector_for_multiple_types(self):
def test_custom_detector_for_multiple_types(self) -> None:
"""Test using a single middleware with custom detector for multiple PII types.
This is an alternative to using multiple middleware instances,
@@ -619,7 +628,7 @@ class TestMultipleMiddleware:
"""
# Combine multiple detectors into one
def detect_email_and_ip(content):
def detect_email_and_ip(content: str) -> list[PIIMatch]:
return detect_email(content) + detect_ip(content)
middleware = PIIMiddleware(
@@ -628,9 +637,10 @@ class TestMultipleMiddleware:
strategy="redact",
)
state = {"messages": [HumanMessage("Email: test@example.com, IP: 10.0.0.1")]}
result = middleware.before_model(state, None)
state = AgentState[Any](messages=[HumanMessage("Email: test@example.com, IP: 10.0.0.1")])
result = middleware.before_model(state, Runtime())
assert result is not None
content = result["messages"][0].content
assert "test@example.com" not in content
assert "10.0.0.1" not in content