Mirror of https://github.com/hwchase17/langchain.git, synced 2025-09-16 06:53:16 +00:00
Comprehend Moderation 0.2 (#11730)
This PR replaces the previous `Intent` check with the new `Prompt Safety` check. The logic and steps for enabling chain moderation via the Amazon Comprehend service, which lets you detect and redact PII, toxic content, and unsafe prompts in the LLM prompt or answer, remain unchanged. This implementation updates the code and configuration types for `Prompt Safety`.

### Usage sample

```python
import boto3
from langchain.chains import LLMChain
from langchain.llms.fake import FakeListLLM
from langchain.prompts import PromptTemplate

from langchain_experimental.comprehend_moderation import (
    AmazonComprehendModerationChain,
    BaseModerationConfig,
    ModerationPiiConfig,
    ModerationPromptSafetyConfig,
    ModerationToxicityConfig,
)

# Boto3 client setup added for a runnable example; the region is a placeholder.
comprehend_client = boto3.client("comprehend", region_name="us-east-1")

pii_config = ModerationPiiConfig(labels=["SSN"], redact=True, mask_character="X")

toxicity_config = ModerationToxicityConfig(threshold=0.5)

prompt_safety_config = ModerationPromptSafetyConfig(threshold=0.5)

moderation_config = BaseModerationConfig(
    filters=[pii_config, toxicity_config, prompt_safety_config]
)

comp_moderation_with_config = AmazonComprehendModerationChain(
    moderation_config=moderation_config,  # specify the configuration
    client=comprehend_client,  # optionally pass the Boto3 client
    verbose=True,
)

template = """Question: {question}

Answer:"""

prompt = PromptTemplate(template=template, input_variables=["question"])

responses = [
    "Final Answer: A credit card number looks like 1289-2321-1123-2387. A fake SSN number looks like 323-22-9980. John Doe's phone number is (999)253-9876.",
    "Final Answer: This is a really shitty way of constructing a birdhouse. This is fucking insane to think that any birds would actually create their motherfucking nests here.",
]

llm = FakeListLLM(responses=responses)

llm_chain = LLMChain(prompt=prompt, llm=llm)

chain = (
    prompt
    | comp_moderation_with_config
    | {llm_chain.input_keys[0]: lambda x: x["output"]}
    | llm_chain
    | {"input": lambda x: x["text"]}
    | comp_moderation_with_config
)

try:
    response = chain.invoke(
        {
            "question": "A sample SSN number looks like this 123-456-7890. Can you give me some more samples?"
        }
    )
except Exception as e:
    print(str(e))
else:
    print(response["output"])
```

### Output

```
> Entering new AmazonComprehendModerationChain chain...
Running AmazonComprehendModerationChain...
Running pii Validation...
Running toxicity Validation...
Running prompt safety Validation...

> Finished chain.


> Entering new AmazonComprehendModerationChain chain...
Running AmazonComprehendModerationChain...
Running pii Validation...
Running toxicity Validation...
Running prompt safety Validation...

> Finished chain.
Final Answer: A credit card number looks like 1289-2321-1123-2387. A fake SSN number looks like XXXXXXXXXXXX John Doe's phone number is (999)253-9876.
```

---------

Co-authored-by: Jha <nikjha@amazon.com>
Co-authored-by: Anjan Biswas <anjanavb@amazon.com>
Co-authored-by: Anjan Biswas <84933469+anjanvb@users.noreply.github.com>
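If you want to handle the moderation failure explicitly rather than catching a bare `Exception`, the exception classes touched by this PR can be caught around `chain.invoke`. A minimal sketch, reusing `chain` from the usage sample above (import path and class names as in the diff below):

```python
from langchain_experimental.comprehend_moderation.base_moderation_exceptions import (
    ModerationPiiError,
    ModerationPromptSafetyError,
    ModerationToxicityError,
)

try:
    response = chain.invoke(
        {"question": "A sample SSN number looks like this 123-456-7890."}
    )
except ModerationPromptSafetyError as e:
    # New in this PR: raised when the prompt safety classifier scores at or
    # above the configured threshold.
    print(f"Unsafe prompt: {e}")
except (ModerationPiiError, ModerationToxicityError) as e:
    # PII and toxicity checks behave as before.
    print(f"Moderation stopped the chain: {e}")
else:
    print(response["output"])
```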
```diff
@@ -7,23 +7,25 @@ from langchain_experimental.comprehend_moderation.base_moderation_callbacks impo
 )
 from langchain_experimental.comprehend_moderation.base_moderation_config import (
     BaseModerationConfig,
-    ModerationIntentConfig,
     ModerationPiiConfig,
+    ModerationPromptSafetyConfig,
     ModerationToxicityConfig,
 )
-from langchain_experimental.comprehend_moderation.intent import ComprehendIntent
 from langchain_experimental.comprehend_moderation.pii import ComprehendPII
+from langchain_experimental.comprehend_moderation.prompt_safety import (
+    ComprehendPromptSafety,
+)
 from langchain_experimental.comprehend_moderation.toxicity import ComprehendToxicity

 __all__ = [
     "BaseModeration",
     "ComprehendPII",
-    "ComprehendIntent",
+    "ComprehendPromptSafety",
     "ComprehendToxicity",
     "BaseModerationConfig",
     "ModerationPiiConfig",
     "ModerationToxicityConfig",
-    "ModerationIntentConfig",
+    "ModerationPromptSafetyConfig",
     "BaseModerationCallbackHandler",
     "AmazonComprehendModerationChain",
 ]
```
```diff
@@ -6,8 +6,10 @@ from langchain.prompts.base import StringPromptValue
 from langchain.prompts.chat import ChatPromptValue
 from langchain.schema import AIMessage, HumanMessage

-from langchain_experimental.comprehend_moderation.intent import ComprehendIntent
 from langchain_experimental.comprehend_moderation.pii import ComprehendPII
+from langchain_experimental.comprehend_moderation.prompt_safety import (
+    ComprehendPromptSafety,
+)
 from langchain_experimental.comprehend_moderation.toxicity import ComprehendToxicity


@@ -109,13 +111,13 @@ class BaseModeration:

     def moderate(self, prompt: Any) -> str:
         from langchain_experimental.comprehend_moderation.base_moderation_config import (  # noqa: E501
-            ModerationIntentConfig,
             ModerationPiiConfig,
+            ModerationPromptSafetyConfig,
             ModerationToxicityConfig,
         )
         from langchain_experimental.comprehend_moderation.base_moderation_exceptions import (  # noqa: E501
-            ModerationIntentionError,
             ModerationPiiError,
+            ModerationPromptSafetyError,
             ModerationToxicityError,
         )

@@ -128,7 +130,7 @@ class BaseModeration:
         filter_functions = {
             "pii": ComprehendPII,
             "toxicity": ComprehendToxicity,
-            "intent": ComprehendIntent,
+            "prompt_safety": ComprehendPromptSafety,
         }

         filters = self.config.filters  # type: ignore
@@ -141,8 +143,8 @@ class BaseModeration:
                 "toxicity"
                 if isinstance(_filter, ModerationToxicityConfig)
                 else (
-                    "intent"
-                    if isinstance(_filter, ModerationIntentConfig)
+                    "prompt_safety"
+                    if isinstance(_filter, ModerationPromptSafetyConfig)
                     else None
                 )
             )
@@ -171,7 +173,7 @@ class BaseModeration:
                 f"Found Toxic content..stopping..\n{str(e)}\n"
             )
             raise e
-        except ModerationIntentionError as e:
+        except ModerationPromptSafetyError as e:
             self._log_message_for_verbose(
                 f"Found Harmful intention..stopping..\n{str(e)}\n"
             )
```
```diff
@@ -11,12 +11,13 @@ class BaseModerationCallbackHandler:
                 BaseModerationCallbackHandler.on_after_toxicity, self.on_after_toxicity
             )
             and self._is_method_unchanged(
-                BaseModerationCallbackHandler.on_after_intent, self.on_after_intent
+                BaseModerationCallbackHandler.on_after_prompt_safety,
+                self.on_after_prompt_safety,
             )
         ):
             raise NotImplementedError(
                 "Subclasses must override at least one of on_after_pii(), "
-                "on_after_toxicity(), or on_after_intent() functions."
+                "on_after_toxicity(), or on_after_prompt_safety() functions."
             )

     def _is_method_unchanged(
@@ -36,10 +37,10 @@ class BaseModerationCallbackHandler:
         """Run after Toxicity validation is complete."""
         pass

-    async def on_after_intent(
+    async def on_after_prompt_safety(
         self, moderation_beacon: Dict[str, Any], unique_id: str, **kwargs: Any
     ) -> None:
-        """Run after Toxicity validation is complete."""
+        """Run after Prompt Safety validation is complete."""
         pass

     @property
@@ -57,8 +58,8 @@ class BaseModerationCallbackHandler:
         )

     @property
-    def intent_callback(self) -> bool:
+    def prompt_safety_callback(self) -> bool:
         return (
-            self.on_after_intent.__func__  # type: ignore
-            is not BaseModerationCallbackHandler.on_after_intent
+            self.on_after_prompt_safety.__func__  # type: ignore
+            is not BaseModerationCallbackHandler.on_after_prompt_safety
         )
```
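Custom callbacks follow the same rename: override `on_after_prompt_safety` instead of `on_after_intent`. Below is a minimal sketch of a subclass and how it might be wired into the chain; the hook signature matches the diff above, while the `moderation_callback` constructor argument and the region are assumptions.

```python
from typing import Any, Dict

import boto3

from langchain_experimental.comprehend_moderation import (
    AmazonComprehendModerationChain,
    BaseModerationCallbackHandler,
    BaseModerationConfig,
)


class MyModerationCallback(BaseModerationCallbackHandler):
    async def on_after_prompt_safety(
        self, moderation_beacon: Dict[str, Any], unique_id: str, **kwargs: Any
    ) -> None:
        # moderation_beacon carries the chain id, the moderation type, the raw
        # Comprehend response, and a LABELS_FOUND / LABELS_NOT_FOUND status.
        print(unique_id, moderation_beacon["moderation_status"])


# Assumed wiring: the handler is passed via `moderation_callback`.
moderation_chain = AmazonComprehendModerationChain(
    moderation_config=BaseModerationConfig(),  # defaults now include prompt safety
    client=boto3.client("comprehend", region_name="us-east-1"),  # example region
    moderation_callback=MyModerationCallback(),
    verbose=True,
)
```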
```diff
@@ -28,24 +28,26 @@ class ModerationToxicityConfig(BaseModel):
     """List of toxic labels, defaults to `list[]`"""


-class ModerationIntentConfig(BaseModel):
+class ModerationPromptSafetyConfig(BaseModel):
     threshold: float = 0.5
     """
-    Threshold for Intent classification
+    Threshold for Prompt Safety classification
     confidence score, defaults to 0.5 i.e. 50%
     """


 class BaseModerationConfig(BaseModel):
     filters: List[
-        Union[ModerationPiiConfig, ModerationToxicityConfig, ModerationIntentConfig]
+        Union[
+            ModerationPiiConfig, ModerationToxicityConfig, ModerationPromptSafetyConfig
+        ]
     ] = [
         ModerationPiiConfig(),
         ModerationToxicityConfig(),
-        ModerationIntentConfig(),
+        ModerationPromptSafetyConfig(),
     ]
     """
     Filters applied to the moderation chain, defaults to
     `[ModerationPiiConfig(), ModerationToxicityConfig(),
-    ModerationIntentConfig()]`
+    ModerationPromptSafetyConfig()]`
     """
```
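As the hunk above shows, the defaults change as well: an unconfigured `BaseModerationConfig` now runs the prompt safety filter alongside PII and toxicity. A short sketch of the two equivalent forms:

```python
from langchain_experimental.comprehend_moderation import (
    BaseModerationConfig,
    ModerationPiiConfig,
    ModerationPromptSafetyConfig,
    ModerationToxicityConfig,
)

# Default filters now include prompt safety (threshold 0.5) instead of intent.
default_config = BaseModerationConfig()

# Equivalent explicit form.
explicit_config = BaseModerationConfig(
    filters=[
        ModerationPiiConfig(),
        ModerationToxicityConfig(),
        ModerationPromptSafetyConfig(threshold=0.5),
    ]
)
```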
```diff
@@ -26,7 +26,7 @@ class ModerationToxicityError(Exception):
         super().__init__(self.message)


-class ModerationIntentionError(Exception):
+class ModerationPromptSafetyError(Exception):
     """Exception raised if Intention entities are detected.

     Attributes:
@@ -35,9 +35,7 @@ class ModerationIntentionError(Exception):

     def __init__(
         self,
-        message: str = (
-            "The prompt indicates an un-desired intent and " "cannot be processed"
-        ),
+        message: str = ("The prompt is unsafe and cannot be processed"),
     ):
         self.message = message
         super().__init__(self.message)
```
```diff
@@ -2,11 +2,11 @@ import asyncio
 from typing import Any, Optional

 from langchain_experimental.comprehend_moderation.base_moderation_exceptions import (
-    ModerationIntentionError,
+    ModerationPromptSafetyError,
 )


-class ComprehendIntent:
+class ComprehendPromptSafety:
     def __init__(
         self,
         client: Any,
@@ -17,7 +17,7 @@ class ComprehendIntent:
         self.client = client
         self.moderation_beacon = {
             "moderation_chain_id": chain_id,
-            "moderation_type": "Intent",
+            "moderation_type": "PromptSafety",
             "moderation_status": "LABELS_NOT_FOUND",
         }
         self.callback = callback
@@ -26,62 +26,62 @@ class ComprehendIntent:
     def _get_arn(self) -> str:
         region_name = self.client.meta.region_name
         service = "comprehend"
-        intent_endpoint = "document-classifier-endpoint/prompt-intent"
-        return f"arn:aws:{service}:{region_name}:aws:{intent_endpoint}"
+        prompt_safety_endpoint = "document-classifier-endpoint/prompt-safety"
+        return f"arn:aws:{service}:{region_name}:aws:{prompt_safety_endpoint}"

     def validate(self, prompt_value: str, config: Any = None) -> str:
         """
-        Check and validate the intent of the given prompt text.
+        Check and validate the safety of the given prompt text.

         Args:
-            prompt_value (str): The input text to be checked for unintended intent.
-            config (Dict[str, Any]): Configuration settings for intent checks.
+            prompt_value (str): The input text to be checked for unsafe text.
+            config (Dict[str, Any]): Configuration settings for prompt safety checks.

         Raises:
-            ValueError: If unintended intent is found in the prompt text based
+            ValueError: If unsafe prompt is found in the prompt text based
             on the specified threshold.

         Returns:
             str: The input prompt_value.

         Note:
-            This function checks the intent of the provided prompt text using
-            Comprehend's classify_document API and raises an error if unintended
-            intent is detected with a score above the specified threshold.
+            This function checks the safety of the provided prompt text using
+            Comprehend's classify_document API and raises an error if unsafe
+            text is detected with a score above the specified threshold.

         Example:
             comprehend_client = boto3.client('comprehend')
             prompt_text = "Please tell me your credit card information."
             config = {"threshold": 0.7}
-            checked_prompt = check_intent(comprehend_client, prompt_text, config)
+            checked_prompt = check_prompt_safety(comprehend_client, prompt_text, config)
         """

         threshold = config.get("threshold")
-        intent_found = False
+        unsafe_prompt = False

         endpoint_arn = self._get_arn()
         response = self.client.classify_document(
             Text=prompt_value, EndpointArn=endpoint_arn
         )

-        if self.callback and self.callback.intent_callback:
+        if self.callback and self.callback.prompt_safety_callback:
             self.moderation_beacon["moderation_input"] = prompt_value
             self.moderation_beacon["moderation_output"] = response

         for class_result in response["Classes"]:
             if (
                 class_result["Score"] >= threshold
-                and class_result["Name"] == "UNDESIRED_PROMPT"
+                and class_result["Name"] == "UNSAFE_PROMPT"
             ):
-                intent_found = True
+                unsafe_prompt = True
                 break

         if self.callback and self.callback.intent_callback:
-            if intent_found:
+            if unsafe_prompt:
                 self.moderation_beacon["moderation_status"] = "LABELS_FOUND"
             asyncio.create_task(
                 self.callback.on_after_intent(self.moderation_beacon, self.unique_id)
             )
-        if intent_found:
-            raise ModerationIntentionError
+        if unsafe_prompt:
+            raise ModerationPromptSafetyError
         return prompt_value
```
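For reference, the renamed check still reads the standard Comprehend `classify_document` response. The sketch below replays the decision logic above against a hand-written response of the same shape; the score is made up, and `UNSAFE_PROMPT` is the only label the code inspects.

```python
# Only the fields read by ComprehendPromptSafety.validate() are shown here.
response = {"Classes": [{"Name": "UNSAFE_PROMPT", "Score": 0.79}]}

threshold = 0.5
unsafe_prompt = any(
    c["Score"] >= threshold and c["Name"] == "UNSAFE_PROMPT"
    for c in response["Classes"]
)
print(unsafe_prompt)  # True -> the chain raises ModerationPromptSafetyError
```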