mirror of
https://github.com/hwchase17/langchain.git
synced 2026-06-09 10:17:00 +00:00
chore(langchain): fix types in test_pii (#34617)
This commit is contained in:
committed by
GitHub
parent
6537939f53
commit
7f4f130479
@@ -360,6 +360,7 @@ class PIIMiddleware(AgentMiddleware):
|
|||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"PIIDetectionError",
|
"PIIDetectionError",
|
||||||
|
"PIIMatch",
|
||||||
"PIIMiddleware",
|
"PIIMiddleware",
|
||||||
"detect_credit_card",
|
"detect_credit_card",
|
||||||
"detect_email",
|
"detect_email",
|
||||||
|
|||||||
@@ -1,11 +1,16 @@
|
|||||||
"""Tests for PII detection middleware."""
|
"""Tests for PII detection middleware."""
|
||||||
|
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
from langchain_core.messages import AIMessage, HumanMessage, ToolCall, ToolMessage
|
from langchain_core.messages import AIMessage, HumanMessage, ToolCall, ToolMessage
|
||||||
|
from langgraph.runtime import Runtime
|
||||||
|
|
||||||
|
from langchain.agents import AgentState
|
||||||
from langchain.agents.factory import create_agent
|
from langchain.agents.factory import create_agent
|
||||||
from langchain.agents.middleware.pii import (
|
from langchain.agents.middleware.pii import (
|
||||||
PIIDetectionError,
|
PIIDetectionError,
|
||||||
|
PIIMatch,
|
||||||
PIIMiddleware,
|
PIIMiddleware,
|
||||||
detect_credit_card,
|
detect_credit_card,
|
||||||
detect_email,
|
detect_email,
|
||||||
@@ -23,7 +28,7 @@ from tests.unit_tests.agents.model import FakeToolCallingModel
|
|||||||
class TestEmailDetection:
|
class TestEmailDetection:
|
||||||
"""Test email detection."""
|
"""Test email detection."""
|
||||||
|
|
||||||
def test_detect_valid_email(self):
|
def test_detect_valid_email(self) -> None:
|
||||||
content = "Contact me at john.doe@example.com for more info."
|
content = "Contact me at john.doe@example.com for more info."
|
||||||
matches = detect_email(content)
|
matches = detect_email(content)
|
||||||
|
|
||||||
@@ -33,7 +38,7 @@ class TestEmailDetection:
|
|||||||
assert matches[0]["start"] == 14
|
assert matches[0]["start"] == 14
|
||||||
assert matches[0]["end"] == 34
|
assert matches[0]["end"] == 34
|
||||||
|
|
||||||
def test_detect_multiple_emails(self):
|
def test_detect_multiple_emails(self) -> None:
|
||||||
content = "Email alice@test.com or bob@company.org"
|
content = "Email alice@test.com or bob@company.org"
|
||||||
matches = detect_email(content)
|
matches = detect_email(content)
|
||||||
|
|
||||||
@@ -41,12 +46,12 @@ class TestEmailDetection:
|
|||||||
assert matches[0]["value"] == "alice@test.com"
|
assert matches[0]["value"] == "alice@test.com"
|
||||||
assert matches[1]["value"] == "bob@company.org"
|
assert matches[1]["value"] == "bob@company.org"
|
||||||
|
|
||||||
def test_no_email(self):
|
def test_no_email(self) -> None:
|
||||||
content = "This text has no email addresses."
|
content = "This text has no email addresses."
|
||||||
matches = detect_email(content)
|
matches = detect_email(content)
|
||||||
assert len(matches) == 0
|
assert len(matches) == 0
|
||||||
|
|
||||||
def test_invalid_email_format(self):
|
def test_invalid_email_format(self) -> None:
|
||||||
content = "Invalid emails: @test.com, user@, user@domain"
|
content = "Invalid emails: @test.com, user@, user@domain"
|
||||||
matches = detect_email(content)
|
matches = detect_email(content)
|
||||||
# Should not match invalid formats
|
# Should not match invalid formats
|
||||||
@@ -56,7 +61,7 @@ class TestEmailDetection:
|
|||||||
class TestCreditCardDetection:
|
class TestCreditCardDetection:
|
||||||
"""Test credit card detection with Luhn validation."""
|
"""Test credit card detection with Luhn validation."""
|
||||||
|
|
||||||
def test_detect_valid_credit_card(self):
|
def test_detect_valid_credit_card(self) -> None:
|
||||||
# Valid Visa test number
|
# Valid Visa test number
|
||||||
content = "Card: 4532015112830366"
|
content = "Card: 4532015112830366"
|
||||||
matches = detect_credit_card(content)
|
matches = detect_credit_card(content)
|
||||||
@@ -65,7 +70,7 @@ class TestCreditCardDetection:
|
|||||||
assert matches[0]["type"] == "credit_card"
|
assert matches[0]["type"] == "credit_card"
|
||||||
assert matches[0]["value"] == "4532015112830366"
|
assert matches[0]["value"] == "4532015112830366"
|
||||||
|
|
||||||
def test_detect_credit_card_with_spaces(self):
|
def test_detect_credit_card_with_spaces(self) -> None:
|
||||||
# Valid Mastercard test number
|
# Valid Mastercard test number
|
||||||
# Add spaces
|
# Add spaces
|
||||||
spaced_content = "Card: 5425 2334 3010 9903"
|
spaced_content = "Card: 5425 2334 3010 9903"
|
||||||
@@ -74,19 +79,19 @@ class TestCreditCardDetection:
|
|||||||
assert len(matches) == 1
|
assert len(matches) == 1
|
||||||
assert "5425 2334 3010 9903" in matches[0]["value"]
|
assert "5425 2334 3010 9903" in matches[0]["value"]
|
||||||
|
|
||||||
def test_detect_credit_card_with_dashes(self):
|
def test_detect_credit_card_with_dashes(self) -> None:
|
||||||
content = "Card: 4532-0151-1283-0366"
|
content = "Card: 4532-0151-1283-0366"
|
||||||
matches = detect_credit_card(content)
|
matches = detect_credit_card(content)
|
||||||
|
|
||||||
assert len(matches) == 1
|
assert len(matches) == 1
|
||||||
|
|
||||||
def test_invalid_luhn_not_detected(self):
|
def test_invalid_luhn_not_detected(self) -> None:
|
||||||
# Invalid Luhn checksum
|
# Invalid Luhn checksum
|
||||||
content = "Card: 1234567890123456"
|
content = "Card: 1234567890123456"
|
||||||
matches = detect_credit_card(content)
|
matches = detect_credit_card(content)
|
||||||
assert len(matches) == 0
|
assert len(matches) == 0
|
||||||
|
|
||||||
def test_no_credit_card(self):
|
def test_no_credit_card(self) -> None:
|
||||||
content = "No cards here."
|
content = "No cards here."
|
||||||
matches = detect_credit_card(content)
|
matches = detect_credit_card(content)
|
||||||
assert len(matches) == 0
|
assert len(matches) == 0
|
||||||
@@ -95,7 +100,7 @@ class TestCreditCardDetection:
|
|||||||
class TestIPDetection:
|
class TestIPDetection:
|
||||||
"""Test IP address detection."""
|
"""Test IP address detection."""
|
||||||
|
|
||||||
def test_detect_valid_ipv4(self):
|
def test_detect_valid_ipv4(self) -> None:
|
||||||
content = "Server IP: 192.168.1.1"
|
content = "Server IP: 192.168.1.1"
|
||||||
matches = detect_ip(content)
|
matches = detect_ip(content)
|
||||||
|
|
||||||
@@ -103,7 +108,7 @@ class TestIPDetection:
|
|||||||
assert matches[0]["type"] == "ip"
|
assert matches[0]["type"] == "ip"
|
||||||
assert matches[0]["value"] == "192.168.1.1"
|
assert matches[0]["value"] == "192.168.1.1"
|
||||||
|
|
||||||
def test_detect_multiple_ips(self):
|
def test_detect_multiple_ips(self) -> None:
|
||||||
content = "Connect to 10.0.0.1 or 8.8.8.8"
|
content = "Connect to 10.0.0.1 or 8.8.8.8"
|
||||||
matches = detect_ip(content)
|
matches = detect_ip(content)
|
||||||
|
|
||||||
@@ -111,13 +116,13 @@ class TestIPDetection:
|
|||||||
assert matches[0]["value"] == "10.0.0.1"
|
assert matches[0]["value"] == "10.0.0.1"
|
||||||
assert matches[1]["value"] == "8.8.8.8"
|
assert matches[1]["value"] == "8.8.8.8"
|
||||||
|
|
||||||
def test_invalid_ip_not_detected(self):
|
def test_invalid_ip_not_detected(self) -> None:
|
||||||
# Out of range octets
|
# Out of range octets
|
||||||
content = "Not an IP: 999.999.999.999"
|
content = "Not an IP: 999.999.999.999"
|
||||||
matches = detect_ip(content)
|
matches = detect_ip(content)
|
||||||
assert len(matches) == 0
|
assert len(matches) == 0
|
||||||
|
|
||||||
def test_version_number_not_detected(self):
|
def test_version_number_not_detected(self) -> None:
|
||||||
# Version numbers should not be detected as IPs
|
# Version numbers should not be detected as IPs
|
||||||
content = "Version 1.2.3.4 released"
|
content = "Version 1.2.3.4 released"
|
||||||
matches = detect_ip(content)
|
matches = detect_ip(content)
|
||||||
@@ -125,7 +130,7 @@ class TestIPDetection:
|
|||||||
# This is acceptable behavior
|
# This is acceptable behavior
|
||||||
assert len(matches) >= 0
|
assert len(matches) >= 0
|
||||||
|
|
||||||
def test_no_ip(self):
|
def test_no_ip(self) -> None:
|
||||||
content = "No IP addresses here."
|
content = "No IP addresses here."
|
||||||
matches = detect_ip(content)
|
matches = detect_ip(content)
|
||||||
assert len(matches) == 0
|
assert len(matches) == 0
|
||||||
@@ -134,7 +139,7 @@ class TestIPDetection:
|
|||||||
class TestMACAddressDetection:
|
class TestMACAddressDetection:
|
||||||
"""Test MAC address detection."""
|
"""Test MAC address detection."""
|
||||||
|
|
||||||
def test_detect_mac_with_colons(self):
|
def test_detect_mac_with_colons(self) -> None:
|
||||||
content = "MAC: 00:1A:2B:3C:4D:5E"
|
content = "MAC: 00:1A:2B:3C:4D:5E"
|
||||||
matches = detect_mac_address(content)
|
matches = detect_mac_address(content)
|
||||||
|
|
||||||
@@ -142,26 +147,26 @@ class TestMACAddressDetection:
|
|||||||
assert matches[0]["type"] == "mac_address"
|
assert matches[0]["type"] == "mac_address"
|
||||||
assert matches[0]["value"] == "00:1A:2B:3C:4D:5E"
|
assert matches[0]["value"] == "00:1A:2B:3C:4D:5E"
|
||||||
|
|
||||||
def test_detect_mac_with_dashes(self):
|
def test_detect_mac_with_dashes(self) -> None:
|
||||||
content = "MAC: 00-1A-2B-3C-4D-5E"
|
content = "MAC: 00-1A-2B-3C-4D-5E"
|
||||||
matches = detect_mac_address(content)
|
matches = detect_mac_address(content)
|
||||||
|
|
||||||
assert len(matches) == 1
|
assert len(matches) == 1
|
||||||
assert matches[0]["value"] == "00-1A-2B-3C-4D-5E"
|
assert matches[0]["value"] == "00-1A-2B-3C-4D-5E"
|
||||||
|
|
||||||
def test_detect_lowercase_mac(self):
|
def test_detect_lowercase_mac(self) -> None:
|
||||||
content = "MAC: aa:bb:cc:dd:ee:ff"
|
content = "MAC: aa:bb:cc:dd:ee:ff"
|
||||||
matches = detect_mac_address(content)
|
matches = detect_mac_address(content)
|
||||||
|
|
||||||
assert len(matches) == 1
|
assert len(matches) == 1
|
||||||
assert matches[0]["value"] == "aa:bb:cc:dd:ee:ff"
|
assert matches[0]["value"] == "aa:bb:cc:dd:ee:ff"
|
||||||
|
|
||||||
def test_no_mac(self):
|
def test_no_mac(self) -> None:
|
||||||
content = "No MAC address here."
|
content = "No MAC address here."
|
||||||
matches = detect_mac_address(content)
|
matches = detect_mac_address(content)
|
||||||
assert len(matches) == 0
|
assert len(matches) == 0
|
||||||
|
|
||||||
def test_partial_mac_not_detected(self):
|
def test_partial_mac_not_detected(self) -> None:
|
||||||
content = "Partial: 00:1A:2B:3C"
|
content = "Partial: 00:1A:2B:3C"
|
||||||
matches = detect_mac_address(content)
|
matches = detect_mac_address(content)
|
||||||
assert len(matches) == 0
|
assert len(matches) == 0
|
||||||
@@ -170,7 +175,7 @@ class TestMACAddressDetection:
|
|||||||
class TestURLDetection:
|
class TestURLDetection:
|
||||||
"""Test URL detection."""
|
"""Test URL detection."""
|
||||||
|
|
||||||
def test_detect_http_url(self):
|
def test_detect_http_url(self) -> None:
|
||||||
content = "Visit http://example.com for details."
|
content = "Visit http://example.com for details."
|
||||||
matches = detect_url(content)
|
matches = detect_url(content)
|
||||||
|
|
||||||
@@ -178,39 +183,39 @@ class TestURLDetection:
|
|||||||
assert matches[0]["type"] == "url"
|
assert matches[0]["type"] == "url"
|
||||||
assert matches[0]["value"] == "http://example.com"
|
assert matches[0]["value"] == "http://example.com"
|
||||||
|
|
||||||
def test_detect_https_url(self):
|
def test_detect_https_url(self) -> None:
|
||||||
content = "Visit https://secure.example.com/path"
|
content = "Visit https://secure.example.com/path"
|
||||||
matches = detect_url(content)
|
matches = detect_url(content)
|
||||||
|
|
||||||
assert len(matches) == 1
|
assert len(matches) == 1
|
||||||
assert matches[0]["value"] == "https://secure.example.com/path"
|
assert matches[0]["value"] == "https://secure.example.com/path"
|
||||||
|
|
||||||
def test_detect_www_url(self):
|
def test_detect_www_url(self) -> None:
|
||||||
content = "Check www.example.com"
|
content = "Check www.example.com"
|
||||||
matches = detect_url(content)
|
matches = detect_url(content)
|
||||||
|
|
||||||
assert len(matches) == 1
|
assert len(matches) == 1
|
||||||
assert matches[0]["value"] == "www.example.com"
|
assert matches[0]["value"] == "www.example.com"
|
||||||
|
|
||||||
def test_detect_bare_domain_with_path(self):
|
def test_detect_bare_domain_with_path(self) -> None:
|
||||||
content = "Go to example.com/page"
|
content = "Go to example.com/page"
|
||||||
matches = detect_url(content)
|
matches = detect_url(content)
|
||||||
|
|
||||||
assert len(matches) == 1
|
assert len(matches) == 1
|
||||||
assert matches[0]["value"] == "example.com/page"
|
assert matches[0]["value"] == "example.com/page"
|
||||||
|
|
||||||
def test_detect_multiple_urls(self):
|
def test_detect_multiple_urls(self) -> None:
|
||||||
content = "Visit http://test.com and https://example.org"
|
content = "Visit http://test.com and https://example.org"
|
||||||
matches = detect_url(content)
|
matches = detect_url(content)
|
||||||
|
|
||||||
assert len(matches) == 2
|
assert len(matches) == 2
|
||||||
|
|
||||||
def test_no_url(self):
|
def test_no_url(self) -> None:
|
||||||
content = "No URLs here."
|
content = "No URLs here."
|
||||||
matches = detect_url(content)
|
matches = detect_url(content)
|
||||||
assert len(matches) == 0
|
assert len(matches) == 0
|
||||||
|
|
||||||
def test_bare_domain_without_path_not_detected(self):
|
def test_bare_domain_without_path_not_detected(self) -> None:
|
||||||
# To reduce false positives, bare domains without paths are not detected
|
# To reduce false positives, bare domains without paths are not detected
|
||||||
content = "The word example.com in prose"
|
content = "The word example.com in prose"
|
||||||
detect_url(content)
|
detect_url(content)
|
||||||
@@ -226,21 +231,21 @@ class TestURLDetection:
|
|||||||
class TestRedactStrategy:
|
class TestRedactStrategy:
|
||||||
"""Test redact strategy."""
|
"""Test redact strategy."""
|
||||||
|
|
||||||
def test_redact_email(self):
|
def test_redact_email(self) -> None:
|
||||||
middleware = PIIMiddleware("email", strategy="redact")
|
middleware = PIIMiddleware("email", strategy="redact")
|
||||||
state = {"messages": [HumanMessage("Email me at test@example.com")]}
|
state = AgentState[Any](messages=[HumanMessage("Email me at test@example.com")])
|
||||||
|
|
||||||
result = middleware.before_model(state, None)
|
result = middleware.before_model(state, Runtime())
|
||||||
|
|
||||||
assert result is not None
|
assert result is not None
|
||||||
assert "[REDACTED_EMAIL]" in result["messages"][0].content
|
assert "[REDACTED_EMAIL]" in result["messages"][0].content
|
||||||
assert "test@example.com" not in result["messages"][0].content
|
assert "test@example.com" not in result["messages"][0].content
|
||||||
|
|
||||||
def test_redact_multiple_pii(self):
|
def test_redact_multiple_pii(self) -> None:
|
||||||
middleware = PIIMiddleware("email", strategy="redact")
|
middleware = PIIMiddleware("email", strategy="redact")
|
||||||
state = {"messages": [HumanMessage("Contact alice@test.com or bob@test.com")]}
|
state = AgentState[Any](messages=[HumanMessage("Contact alice@test.com or bob@test.com")])
|
||||||
|
|
||||||
result = middleware.before_model(state, None)
|
result = middleware.before_model(state, Runtime())
|
||||||
|
|
||||||
assert result is not None
|
assert result is not None
|
||||||
content = result["messages"][0].content
|
content = result["messages"][0].content
|
||||||
@@ -252,34 +257,34 @@ class TestRedactStrategy:
|
|||||||
class TestMaskStrategy:
|
class TestMaskStrategy:
|
||||||
"""Test mask strategy."""
|
"""Test mask strategy."""
|
||||||
|
|
||||||
def test_mask_email(self):
|
def test_mask_email(self) -> None:
|
||||||
middleware = PIIMiddleware("email", strategy="mask")
|
middleware = PIIMiddleware("email", strategy="mask")
|
||||||
state = {"messages": [HumanMessage("Email: user@example.com")]}
|
state = AgentState[Any](messages=[HumanMessage("Email: user@example.com")])
|
||||||
|
|
||||||
result = middleware.before_model(state, None)
|
result = middleware.before_model(state, Runtime())
|
||||||
|
|
||||||
assert result is not None
|
assert result is not None
|
||||||
content = result["messages"][0].content
|
content = result["messages"][0].content
|
||||||
assert "user@****.com" in content
|
assert "user@****.com" in content
|
||||||
assert "user@example.com" not in content
|
assert "user@example.com" not in content
|
||||||
|
|
||||||
def test_mask_credit_card(self):
|
def test_mask_credit_card(self) -> None:
|
||||||
middleware = PIIMiddleware("credit_card", strategy="mask")
|
middleware = PIIMiddleware("credit_card", strategy="mask")
|
||||||
# Valid test card
|
# Valid test card
|
||||||
state = {"messages": [HumanMessage("Card: 4532015112830366")]}
|
state = AgentState[Any](messages=[HumanMessage("Card: 4532015112830366")])
|
||||||
|
|
||||||
result = middleware.before_model(state, None)
|
result = middleware.before_model(state, Runtime())
|
||||||
|
|
||||||
assert result is not None
|
assert result is not None
|
||||||
content = result["messages"][0].content
|
content = result["messages"][0].content
|
||||||
assert "0366" in content # Last 4 digits visible
|
assert "0366" in content # Last 4 digits visible
|
||||||
assert "4532015112830366" not in content
|
assert "4532015112830366" not in content
|
||||||
|
|
||||||
def test_mask_ip(self):
|
def test_mask_ip(self) -> None:
|
||||||
middleware = PIIMiddleware("ip", strategy="mask")
|
middleware = PIIMiddleware("ip", strategy="mask")
|
||||||
state = {"messages": [HumanMessage("IP: 192.168.1.100")]}
|
state = AgentState[Any](messages=[HumanMessage("IP: 192.168.1.100")])
|
||||||
|
|
||||||
result = middleware.before_model(state, None)
|
result = middleware.before_model(state, Runtime())
|
||||||
|
|
||||||
assert result is not None
|
assert result is not None
|
||||||
content = result["messages"][0].content
|
content = result["messages"][0].content
|
||||||
@@ -290,11 +295,11 @@ class TestMaskStrategy:
|
|||||||
class TestHashStrategy:
|
class TestHashStrategy:
|
||||||
"""Test hash strategy."""
|
"""Test hash strategy."""
|
||||||
|
|
||||||
def test_hash_email(self):
|
def test_hash_email(self) -> None:
|
||||||
middleware = PIIMiddleware("email", strategy="hash")
|
middleware = PIIMiddleware("email", strategy="hash")
|
||||||
state = {"messages": [HumanMessage("Email: test@example.com")]}
|
state = AgentState[Any](messages=[HumanMessage("Email: test@example.com")])
|
||||||
|
|
||||||
result = middleware.before_model(state, None)
|
result = middleware.before_model(state, Runtime())
|
||||||
|
|
||||||
assert result is not None
|
assert result is not None
|
||||||
content = result["messages"][0].content
|
content = result["messages"][0].content
|
||||||
@@ -302,39 +307,41 @@ class TestHashStrategy:
|
|||||||
assert ">" in content
|
assert ">" in content
|
||||||
assert "test@example.com" not in content
|
assert "test@example.com" not in content
|
||||||
|
|
||||||
def test_hash_is_deterministic(self):
|
def test_hash_is_deterministic(self) -> None:
|
||||||
middleware = PIIMiddleware("email", strategy="hash")
|
middleware = PIIMiddleware("email", strategy="hash")
|
||||||
|
|
||||||
# Same email should produce same hash
|
# Same email should produce same hash
|
||||||
state1 = {"messages": [HumanMessage("Email: test@example.com")]}
|
state1 = AgentState[Any](messages=[HumanMessage("Email: test@example.com")])
|
||||||
state2 = {"messages": [HumanMessage("Email: test@example.com")]}
|
state2 = AgentState[Any](messages=[HumanMessage("Email: test@example.com")])
|
||||||
|
|
||||||
result1 = middleware.before_model(state1, None)
|
result1 = middleware.before_model(state1, Runtime())
|
||||||
result2 = middleware.before_model(state2, None)
|
result2 = middleware.before_model(state2, Runtime())
|
||||||
|
|
||||||
|
assert result1 is not None
|
||||||
|
assert result2 is not None
|
||||||
assert result1["messages"][0].content == result2["messages"][0].content
|
assert result1["messages"][0].content == result2["messages"][0].content
|
||||||
|
|
||||||
|
|
||||||
class TestBlockStrategy:
|
class TestBlockStrategy:
|
||||||
"""Test block strategy."""
|
"""Test block strategy."""
|
||||||
|
|
||||||
def test_block_raises_exception(self):
|
def test_block_raises_exception(self) -> None:
|
||||||
middleware = PIIMiddleware("email", strategy="block")
|
middleware = PIIMiddleware("email", strategy="block")
|
||||||
state = {"messages": [HumanMessage("Email: test@example.com")]}
|
state = AgentState[Any](messages=[HumanMessage("Email: test@example.com")])
|
||||||
|
|
||||||
with pytest.raises(PIIDetectionError) as exc_info:
|
with pytest.raises(PIIDetectionError) as exc_info:
|
||||||
middleware.before_model(state, None)
|
middleware.before_model(state, Runtime())
|
||||||
|
|
||||||
assert exc_info.value.pii_type == "email"
|
assert exc_info.value.pii_type == "email"
|
||||||
assert len(exc_info.value.matches) == 1
|
assert len(exc_info.value.matches) == 1
|
||||||
assert "test@example.com" in exc_info.value.matches[0]["value"]
|
assert "test@example.com" in exc_info.value.matches[0]["value"]
|
||||||
|
|
||||||
def test_block_with_multiple_matches(self):
|
def test_block_with_multiple_matches(self) -> None:
|
||||||
middleware = PIIMiddleware("email", strategy="block")
|
middleware = PIIMiddleware("email", strategy="block")
|
||||||
state = {"messages": [HumanMessage("Emails: alice@test.com and bob@test.com")]}
|
state = AgentState[Any](messages=[HumanMessage("Emails: alice@test.com and bob@test.com")])
|
||||||
|
|
||||||
with pytest.raises(PIIDetectionError) as exc_info:
|
with pytest.raises(PIIDetectionError) as exc_info:
|
||||||
middleware.before_model(state, None)
|
middleware.before_model(state, Runtime())
|
||||||
|
|
||||||
assert len(exc_info.value.matches) == 2
|
assert len(exc_info.value.matches) == 2
|
||||||
|
|
||||||
@@ -347,81 +354,81 @@ class TestBlockStrategy:
|
|||||||
class TestPIIMiddlewareIntegration:
|
class TestPIIMiddlewareIntegration:
|
||||||
"""Test PIIMiddleware integration with agent."""
|
"""Test PIIMiddleware integration with agent."""
|
||||||
|
|
||||||
def test_apply_to_input_only(self):
|
def test_apply_to_input_only(self) -> None:
|
||||||
"""Test that middleware only processes input when configured."""
|
"""Test that middleware only processes input when configured."""
|
||||||
middleware = PIIMiddleware(
|
middleware = PIIMiddleware(
|
||||||
"email", strategy="redact", apply_to_input=True, apply_to_output=False
|
"email", strategy="redact", apply_to_input=True, apply_to_output=False
|
||||||
)
|
)
|
||||||
|
|
||||||
# Should process HumanMessage
|
# Should process HumanMessage
|
||||||
state = {"messages": [HumanMessage("Email: test@example.com")]}
|
state = AgentState[Any](messages=[HumanMessage("Email: test@example.com")])
|
||||||
result = middleware.before_model(state, None)
|
result = middleware.before_model(state, Runtime())
|
||||||
assert result is not None
|
assert result is not None
|
||||||
assert "[REDACTED_EMAIL]" in result["messages"][0].content
|
assert "[REDACTED_EMAIL]" in result["messages"][0].content
|
||||||
|
|
||||||
# Should not process AIMessage
|
# Should not process AIMessage
|
||||||
state = {"messages": [AIMessage("My email is ai@example.com")]}
|
state = AgentState[Any](messages=[AIMessage("My email is ai@example.com")])
|
||||||
result = middleware.after_model(state, None)
|
result = middleware.after_model(state, Runtime())
|
||||||
assert result is None
|
assert result is None
|
||||||
|
|
||||||
def test_apply_to_output_only(self):
|
def test_apply_to_output_only(self) -> None:
|
||||||
"""Test that middleware only processes output when configured."""
|
"""Test that middleware only processes output when configured."""
|
||||||
middleware = PIIMiddleware(
|
middleware = PIIMiddleware(
|
||||||
"email", strategy="redact", apply_to_input=False, apply_to_output=True
|
"email", strategy="redact", apply_to_input=False, apply_to_output=True
|
||||||
)
|
)
|
||||||
|
|
||||||
# Should not process HumanMessage
|
# Should not process HumanMessage
|
||||||
state = {"messages": [HumanMessage("Email: test@example.com")]}
|
state = AgentState[Any](messages=[HumanMessage("Email: test@example.com")])
|
||||||
result = middleware.before_model(state, None)
|
result = middleware.before_model(state, Runtime())
|
||||||
assert result is None
|
assert result is None
|
||||||
|
|
||||||
# Should process AIMessage
|
# Should process AIMessage
|
||||||
state = {"messages": [AIMessage("My email is ai@example.com")]}
|
state = AgentState[Any](messages=[AIMessage("My email is ai@example.com")])
|
||||||
result = middleware.after_model(state, None)
|
result = middleware.after_model(state, Runtime())
|
||||||
assert result is not None
|
assert result is not None
|
||||||
assert "[REDACTED_EMAIL]" in result["messages"][0].content
|
assert "[REDACTED_EMAIL]" in result["messages"][0].content
|
||||||
|
|
||||||
def test_apply_to_both(self):
|
def test_apply_to_both(self) -> None:
|
||||||
"""Test that middleware processes both input and output."""
|
"""Test that middleware processes both input and output."""
|
||||||
middleware = PIIMiddleware(
|
middleware = PIIMiddleware(
|
||||||
"email", strategy="redact", apply_to_input=True, apply_to_output=True
|
"email", strategy="redact", apply_to_input=True, apply_to_output=True
|
||||||
)
|
)
|
||||||
|
|
||||||
# Should process HumanMessage
|
# Should process HumanMessage
|
||||||
state = {"messages": [HumanMessage("Email: test@example.com")]}
|
state = AgentState[Any](messages=[HumanMessage("Email: test@example.com")])
|
||||||
result = middleware.before_model(state, None)
|
result = middleware.before_model(state, Runtime())
|
||||||
assert result is not None
|
assert result is not None
|
||||||
|
|
||||||
# Should process AIMessage
|
# Should process AIMessage
|
||||||
state = {"messages": [AIMessage("My email is ai@example.com")]}
|
state = AgentState[Any](messages=[AIMessage("My email is ai@example.com")])
|
||||||
result = middleware.after_model(state, None)
|
result = middleware.after_model(state, Runtime())
|
||||||
assert result is not None
|
assert result is not None
|
||||||
|
|
||||||
def test_no_pii_returns_none(self):
|
def test_no_pii_returns_none(self) -> None:
|
||||||
"""Test that middleware returns None when no PII detected."""
|
"""Test that middleware returns None when no PII detected."""
|
||||||
middleware = PIIMiddleware("email", strategy="redact")
|
middleware = PIIMiddleware("email", strategy="redact")
|
||||||
state = {"messages": [HumanMessage("No PII here")]}
|
state = AgentState[Any](messages=[HumanMessage("No PII here")])
|
||||||
|
|
||||||
result = middleware.before_model(state, None)
|
result = middleware.before_model(state, Runtime())
|
||||||
assert result is None
|
assert result is None
|
||||||
|
|
||||||
def test_empty_messages(self):
|
def test_empty_messages(self) -> None:
|
||||||
"""Test that middleware handles empty messages gracefully."""
|
"""Test that middleware handles empty messages gracefully."""
|
||||||
middleware = PIIMiddleware("email", strategy="redact")
|
middleware = PIIMiddleware("email", strategy="redact")
|
||||||
state = {"messages": []}
|
state = AgentState[Any](messages=[])
|
||||||
|
|
||||||
result = middleware.before_model(state, None)
|
result = middleware.before_model(state, Runtime())
|
||||||
assert result is None
|
assert result is None
|
||||||
|
|
||||||
def test_apply_to_tool_results(self):
|
def test_apply_to_tool_results(self) -> None:
|
||||||
"""Test that middleware processes tool results when enabled."""
|
"""Test that middleware processes tool results when enabled."""
|
||||||
middleware = PIIMiddleware(
|
middleware = PIIMiddleware(
|
||||||
"email", strategy="redact", apply_to_input=False, apply_to_tool_results=True
|
"email", strategy="redact", apply_to_input=False, apply_to_tool_results=True
|
||||||
)
|
)
|
||||||
|
|
||||||
# Simulate a conversation with tool call and result containing PII
|
# Simulate a conversation with tool call and result containing PII
|
||||||
state = {
|
state = AgentState[Any](
|
||||||
"messages": [
|
messages=[
|
||||||
HumanMessage("Search for John"),
|
HumanMessage("Search for John"),
|
||||||
AIMessage(
|
AIMessage(
|
||||||
content="",
|
content="",
|
||||||
@@ -429,9 +436,9 @@ class TestPIIMiddlewareIntegration:
|
|||||||
),
|
),
|
||||||
ToolMessage(content="Found: john@example.com", tool_call_id="call_123"),
|
ToolMessage(content="Found: john@example.com", tool_call_id="call_123"),
|
||||||
]
|
]
|
||||||
}
|
)
|
||||||
|
|
||||||
result = middleware.before_model(state, None)
|
result = middleware.before_model(state, Runtime())
|
||||||
|
|
||||||
assert result is not None
|
assert result is not None
|
||||||
# Check that the tool message was redacted
|
# Check that the tool message was redacted
|
||||||
@@ -440,14 +447,14 @@ class TestPIIMiddlewareIntegration:
|
|||||||
assert "[REDACTED_EMAIL]" in tool_msg.content
|
assert "[REDACTED_EMAIL]" in tool_msg.content
|
||||||
assert "john@example.com" not in tool_msg.content
|
assert "john@example.com" not in tool_msg.content
|
||||||
|
|
||||||
def test_apply_to_tool_results_mask_strategy(self):
|
def test_apply_to_tool_results_mask_strategy(self) -> None:
|
||||||
"""Test that mask strategy works for tool results."""
|
"""Test that mask strategy works for tool results."""
|
||||||
middleware = PIIMiddleware(
|
middleware = PIIMiddleware(
|
||||||
"ip", strategy="mask", apply_to_input=False, apply_to_tool_results=True
|
"ip", strategy="mask", apply_to_input=False, apply_to_tool_results=True
|
||||||
)
|
)
|
||||||
|
|
||||||
state = {
|
state = AgentState[Any](
|
||||||
"messages": [
|
messages=[
|
||||||
HumanMessage("Get server IP"),
|
HumanMessage("Get server IP"),
|
||||||
AIMessage(
|
AIMessage(
|
||||||
content="",
|
content="",
|
||||||
@@ -455,23 +462,23 @@ class TestPIIMiddlewareIntegration:
|
|||||||
),
|
),
|
||||||
ToolMessage(content="Server IP: 192.168.1.100", tool_call_id="call_456"),
|
ToolMessage(content="Server IP: 192.168.1.100", tool_call_id="call_456"),
|
||||||
]
|
]
|
||||||
}
|
)
|
||||||
|
|
||||||
result = middleware.before_model(state, None)
|
result = middleware.before_model(state, Runtime())
|
||||||
|
|
||||||
assert result is not None
|
assert result is not None
|
||||||
tool_msg = result["messages"][2]
|
tool_msg = result["messages"][2]
|
||||||
assert "*.*.*.100" in tool_msg.content
|
assert "*.*.*.100" in tool_msg.content
|
||||||
assert "192.168.1.100" not in tool_msg.content
|
assert "192.168.1.100" not in tool_msg.content
|
||||||
|
|
||||||
def test_apply_to_tool_results_block_strategy(self):
|
def test_apply_to_tool_results_block_strategy(self) -> None:
|
||||||
"""Test that block strategy raises error for PII in tool results."""
|
"""Test that block strategy raises error for PII in tool results."""
|
||||||
middleware = PIIMiddleware(
|
middleware = PIIMiddleware(
|
||||||
"email", strategy="block", apply_to_input=False, apply_to_tool_results=True
|
"email", strategy="block", apply_to_input=False, apply_to_tool_results=True
|
||||||
)
|
)
|
||||||
|
|
||||||
state = {
|
state = AgentState[Any](
|
||||||
"messages": [
|
messages=[
|
||||||
HumanMessage("Search for user"),
|
HumanMessage("Search for user"),
|
||||||
AIMessage(
|
AIMessage(
|
||||||
content="",
|
content="",
|
||||||
@@ -479,17 +486,17 @@ class TestPIIMiddlewareIntegration:
|
|||||||
),
|
),
|
||||||
ToolMessage(content="User email: sensitive@example.com", tool_call_id="call_789"),
|
ToolMessage(content="User email: sensitive@example.com", tool_call_id="call_789"),
|
||||||
]
|
]
|
||||||
}
|
)
|
||||||
|
|
||||||
with pytest.raises(PIIDetectionError) as exc_info:
|
with pytest.raises(PIIDetectionError) as exc_info:
|
||||||
middleware.before_model(state, None)
|
middleware.before_model(state, Runtime())
|
||||||
|
|
||||||
assert exc_info.value.pii_type == "email"
|
assert exc_info.value.pii_type == "email"
|
||||||
assert len(exc_info.value.matches) == 1
|
assert len(exc_info.value.matches) == 1
|
||||||
|
|
||||||
def test_with_agent(self):
|
def test_with_agent(self) -> None:
|
||||||
"""Test PIIMiddleware integrated with create_agent."""
|
"""Test PIIMiddleware integrated with create_agent."""
|
||||||
model = FakeToolCallingModel(responses=[AIMessage(content="Thanks for sharing!")])
|
model = FakeToolCallingModel()
|
||||||
|
|
||||||
agent = create_agent(
|
agent = create_agent(
|
||||||
model=model,
|
model=model,
|
||||||
@@ -508,7 +515,7 @@ class TestPIIMiddlewareIntegration:
|
|||||||
class TestCustomDetector:
|
class TestCustomDetector:
|
||||||
"""Test custom detector functionality."""
|
"""Test custom detector functionality."""
|
||||||
|
|
||||||
def test_custom_regex_detector(self):
|
def test_custom_regex_detector(self) -> None:
|
||||||
# Custom regex for API keys
|
# Custom regex for API keys
|
||||||
middleware = PIIMiddleware(
|
middleware = PIIMiddleware(
|
||||||
"api_key",
|
"api_key",
|
||||||
@@ -516,25 +523,25 @@ class TestCustomDetector:
|
|||||||
strategy="redact",
|
strategy="redact",
|
||||||
)
|
)
|
||||||
|
|
||||||
state = {"messages": [HumanMessage("Key: sk-abcdefghijklmnopqrstuvwxyz123456")]}
|
state = AgentState[Any](messages=[HumanMessage("Key: sk-abcdefghijklmnopqrstuvwxyz123456")])
|
||||||
result = middleware.before_model(state, None)
|
result = middleware.before_model(state, Runtime())
|
||||||
|
|
||||||
assert result is not None
|
assert result is not None
|
||||||
assert "[REDACTED_API_KEY]" in result["messages"][0].content
|
assert "[REDACTED_API_KEY]" in result["messages"][0].content
|
||||||
|
|
||||||
def test_custom_callable_detector(self):
|
def test_custom_callable_detector(self) -> None:
|
||||||
# Custom detector function
|
# Custom detector function
|
||||||
def detect_custom(content):
|
def detect_custom(content: str) -> list[PIIMatch]:
|
||||||
matches = []
|
matches = []
|
||||||
if "CONFIDENTIAL" in content:
|
if "CONFIDENTIAL" in content:
|
||||||
idx = content.index("CONFIDENTIAL")
|
idx = content.index("CONFIDENTIAL")
|
||||||
matches.append(
|
matches.append(
|
||||||
{
|
PIIMatch(
|
||||||
"type": "confidential",
|
type="confidential",
|
||||||
"value": "CONFIDENTIAL",
|
value="CONFIDENTIAL",
|
||||||
"start": idx,
|
start=idx,
|
||||||
"end": idx + 12,
|
end=idx + 12,
|
||||||
}
|
)
|
||||||
)
|
)
|
||||||
return matches
|
return matches
|
||||||
|
|
||||||
@@ -544,17 +551,17 @@ class TestCustomDetector:
|
|||||||
strategy="redact",
|
strategy="redact",
|
||||||
)
|
)
|
||||||
|
|
||||||
state = {"messages": [HumanMessage("This is CONFIDENTIAL information")]}
|
state = AgentState[Any](messages=[HumanMessage("This is CONFIDENTIAL information")])
|
||||||
result = middleware.before_model(state, None)
|
result = middleware.before_model(state, Runtime())
|
||||||
|
|
||||||
assert result is not None
|
assert result is not None
|
||||||
assert "[REDACTED_CONFIDENTIAL]" in result["messages"][0].content
|
assert "[REDACTED_CONFIDENTIAL]" in result["messages"][0].content
|
||||||
|
|
||||||
def test_unknown_builtin_type_raises_error(self):
|
def test_unknown_builtin_type_raises_error(self) -> None:
|
||||||
with pytest.raises(ValueError, match="Unknown PII type"):
|
with pytest.raises(ValueError, match="Unknown PII type"):
|
||||||
PIIMiddleware("unknown_type", strategy="redact")
|
PIIMiddleware("unknown_type", strategy="redact")
|
||||||
|
|
||||||
def test_custom_type_without_detector_raises_error(self):
|
def test_custom_type_without_detector_raises_error(self) -> None:
|
||||||
with pytest.raises(ValueError, match="Unknown PII type"):
|
with pytest.raises(ValueError, match="Unknown PII type"):
|
||||||
PIIMiddleware("custom_type", strategy="redact")
|
PIIMiddleware("custom_type", strategy="redact")
|
||||||
|
|
||||||
@@ -562,18 +569,20 @@ class TestCustomDetector:
|
|||||||
class TestMultipleMiddleware:
|
class TestMultipleMiddleware:
|
||||||
"""Test using multiple PII middleware instances."""
|
"""Test using multiple PII middleware instances."""
|
||||||
|
|
||||||
def test_sequential_application(self):
|
def test_sequential_application(self) -> None:
|
||||||
"""Test that multiple PII types are detected when applied sequentially."""
|
"""Test that multiple PII types are detected when applied sequentially."""
|
||||||
# First apply email middleware
|
# First apply email middleware
|
||||||
email_middleware = PIIMiddleware("email", strategy="redact")
|
email_middleware = PIIMiddleware("email", strategy="redact")
|
||||||
state = {"messages": [HumanMessage("Email: test@example.com, IP: 192.168.1.1")]}
|
state = AgentState[Any](messages=[HumanMessage("Email: test@example.com, IP: 192.168.1.1")])
|
||||||
result1 = email_middleware.before_model(state, None)
|
result1 = email_middleware.before_model(state, Runtime())
|
||||||
|
|
||||||
# Then apply IP middleware to the result
|
# Then apply IP middleware to the result
|
||||||
ip_middleware = PIIMiddleware("ip", strategy="mask")
|
ip_middleware = PIIMiddleware("ip", strategy="mask")
|
||||||
state_with_email_redacted = {"messages": result1["messages"]}
|
assert result1 is not None
|
||||||
result2 = ip_middleware.before_model(state_with_email_redacted, None)
|
state_with_email_redacted = AgentState[Any](messages=result1["messages"])
|
||||||
|
result2 = ip_middleware.before_model(state_with_email_redacted, Runtime())
|
||||||
|
|
||||||
|
assert result2 is not None
|
||||||
content = result2["messages"][0].content
|
content = result2["messages"][0].content
|
||||||
|
|
||||||
# Email should be redacted
|
# Email should be redacted
|
||||||
@@ -584,9 +593,9 @@ class TestMultipleMiddleware:
|
|||||||
assert "*.*.*.1" in content
|
assert "*.*.*.1" in content
|
||||||
assert "192.168.1.1" not in content
|
assert "192.168.1.1" not in content
|
||||||
|
|
||||||
def test_multiple_pii_middleware_with_create_agent(self):
|
def test_multiple_pii_middleware_with_create_agent(self) -> None:
|
||||||
"""Test that multiple PIIMiddleware instances work together in create_agent."""
|
"""Test that multiple PIIMiddleware instances work together in create_agent."""
|
||||||
model = FakeToolCallingModel(responses=[AIMessage(content="Response received")])
|
model = FakeToolCallingModel()
|
||||||
|
|
||||||
# Multiple PIIMiddleware instances should work because each has a unique name
|
# Multiple PIIMiddleware instances should work because each has a unique name
|
||||||
agent = create_agent(
|
agent = create_agent(
|
||||||
@@ -611,7 +620,7 @@ class TestMultipleMiddleware:
|
|||||||
# IP should be masked
|
# IP should be masked
|
||||||
assert "192.168.1.100" not in content
|
assert "192.168.1.100" not in content
|
||||||
|
|
||||||
def test_custom_detector_for_multiple_types(self):
|
def test_custom_detector_for_multiple_types(self) -> None:
|
||||||
"""Test using a single middleware with custom detector for multiple PII types.
|
"""Test using a single middleware with custom detector for multiple PII types.
|
||||||
|
|
||||||
This is an alternative to using multiple middleware instances,
|
This is an alternative to using multiple middleware instances,
|
||||||
@@ -619,7 +628,7 @@ class TestMultipleMiddleware:
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
# Combine multiple detectors into one
|
# Combine multiple detectors into one
|
||||||
def detect_email_and_ip(content):
|
def detect_email_and_ip(content: str) -> list[PIIMatch]:
|
||||||
return detect_email(content) + detect_ip(content)
|
return detect_email(content) + detect_ip(content)
|
||||||
|
|
||||||
middleware = PIIMiddleware(
|
middleware = PIIMiddleware(
|
||||||
@@ -628,9 +637,10 @@ class TestMultipleMiddleware:
|
|||||||
strategy="redact",
|
strategy="redact",
|
||||||
)
|
)
|
||||||
|
|
||||||
state = {"messages": [HumanMessage("Email: test@example.com, IP: 10.0.0.1")]}
|
state = AgentState[Any](messages=[HumanMessage("Email: test@example.com, IP: 10.0.0.1")])
|
||||||
result = middleware.before_model(state, None)
|
result = middleware.before_model(state, Runtime())
|
||||||
|
|
||||||
|
assert result is not None
|
||||||
content = result["messages"][0].content
|
content = result["messages"][0].content
|
||||||
assert "test@example.com" not in content
|
assert "test@example.com" not in content
|
||||||
assert "10.0.0.1" not in content
|
assert "10.0.0.1" not in content
|
||||||
|
|||||||
Reference in New Issue
Block a user