mirror of
https://github.com/hwchase17/langchain.git
synced 2026-03-18 02:53:16 +00:00
fix(langchain): normalize custom detector output to prevent KeyError in hash/mask strategies (#35651)
This commit is contained in:
@@ -373,7 +373,25 @@ def resolve_detector(pii_type: str, detector: Detector | str | None) -> Detector
|
||||
]
|
||||
|
||||
return regex_detector
|
||||
return detector
|
||||
|
||||
# Wrap the custom callable to normalize its output.
|
||||
# Custom detectors may return dicts with "text" instead of "value"
|
||||
# and may omit "type". Map them to proper PIIMatch objects so that
|
||||
# downstream strategies (hash, mask) can access match["value"].
|
||||
raw_detector = detector
|
||||
|
||||
def _normalizing_detector(content: str) -> list[PIIMatch]:
    """Run the wrapped custom detector and coerce each hit into a ``PIIMatch``.

    Custom detectors may report the matched text under ``"value"`` or under
    ``"text"``, and may leave out ``"type"``. Every raw match is rebuilt so
    that ``"type"``, ``"value"``, ``"start"`` and ``"end"`` are always set.

    Args:
        content: The text handed to the wrapped detector.

    Returns:
        One ``PIIMatch`` per raw match, in the order the detector produced
        them.

    Raises:
        KeyError: If a raw match has no ``"start"`` or ``"end"`` entry.
    """
    normalized: list[PIIMatch] = []
    for raw in raw_detector(content):
        start, end = raw["start"], raw["end"]
        if "value" in raw:
            value = raw["value"]
        elif "text" in raw:
            value = raw["text"]
        else:
            # Neither key supplied: recover the matched text from the span
            # rather than handing downstream strategies an empty string.
            value = content[start:end]
        normalized.append(
            PIIMatch(
                # Fall back to the PII type this detector was registered for.
                type=raw.get("type", pii_type),
                value=value,
                start=start,
                end=end,
            )
        )
    return normalized
|
||||
|
||||
return _normalizing_detector
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
"""Tests for PII detection middleware."""
|
||||
|
||||
import re
|
||||
from typing import Any
|
||||
|
||||
import pytest
|
||||
@@ -557,6 +558,57 @@ class TestCustomDetector:
|
||||
assert result is not None
|
||||
assert "[REDACTED_CONFIDENTIAL]" in result["messages"][0].content
|
||||
|
||||
def test_custom_callable_detector_with_text_key_hash(self) -> None:
    """Hash strategy must accept custom detectors that emit 'text', not 'value'.

    Regression test for https://github.com/langchain-ai/langchain/issues/35647:
    a custom detector returning {"text", "start", "end"} dicts used to raise
    KeyError: 'value' under the hash or mask strategies.
    """
    phone_re = re.compile(r"\+91[\s.-]?\d{10}")

    def detect_phone(content: str) -> list[dict]:  # type: ignore[type-arg]
        # Deliberately reports the match under "text" and omits "type".
        return [
            {"text": hit.group(), "start": hit.start(), "end": hit.end()}
            for hit in phone_re.finditer(content)
        ]

    middleware = PIIMiddleware(
        "indian_phone",
        detector=detect_phone,
        strategy="hash",
        apply_to_input=True,
    )
    state = AgentState[Any](messages=[HumanMessage("Call +91 9876543210")])

    result = middleware.before_model(state, Runtime())

    assert result is not None
    processed = result["messages"][0].content
    # The phone number must be replaced by a hash placeholder.
    assert "<indian_phone_hash:" in processed
    assert "+91 9876543210" not in processed
|
||||
|
||||
def test_custom_callable_detector_with_text_key_mask(self) -> None:
    """Mask strategy must accept custom detectors that emit 'text', not 'value'."""
    phone_re = re.compile(r"\+91[\s.-]?\d{10}")

    def detect_phone(content: str) -> list[dict]:  # type: ignore[type-arg]
        # Deliberately reports the match under "text" and omits "type".
        return [
            {"text": hit.group(), "start": hit.start(), "end": hit.end()}
            for hit in phone_re.finditer(content)
        ]

    middleware = PIIMiddleware(
        "indian_phone",
        detector=detect_phone,
        strategy="mask",
        apply_to_input=True,
    )
    state = AgentState[Any](messages=[HumanMessage("Call +91 9876543210")])

    result = middleware.before_model(state, Runtime())

    assert result is not None
    processed = result["messages"][0].content
    # The output must contain mask characters and not the raw number.
    assert "****" in processed
    assert "+91 9876543210" not in processed
|
||||
|
||||
def test_unknown_builtin_type_raises_error(self) -> None:
    """Constructing the middleware for "unknown_type" without a detector must raise ValueError."""
    expected_message = "Unknown PII type"
    with pytest.raises(ValueError, match=expected_message):
        PIIMiddleware("unknown_type", strategy="redact")
|
||||
|
||||
Reference in New Issue
Block a user