Mirror of https://github.com/hwchase17/langchain.git (synced 2025-06-30 10:23:30 +00:00)
Harrison/string inplace (#10153)
Co-authored-by: Wrick Talukdar <wrick.talukdar@gmail.com>
Co-authored-by: Anjan Biswas <anjanavb@amazon.com>
Co-authored-by: Jha <nikjha@amazon.com>
Co-authored-by: Lucky-Lance <77819606+Lucky-Lance@users.noreply.github.com>
Co-authored-by: 陆徐东 <luxudong@MacBook-Pro.local>

parent f5af756397
commit 4abe85be57

New file (1422 lines): docs/docs_skeleton/docs/guides/safety/amazon_comprehend_chain.ipynb
File diff suppressed because it is too large.
@@ -111,7 +111,9 @@ class TaskExecutor:
             dep_task = self.id_task_map[dep_id]
             for k, v in task.args.items():
                 if f"<resource-{dep_id}>" in v:
-                    task.args[k].replace(f"<resource-{dep_id}>", dep_task.result)
+                    task.args[k] = task.args[k].replace(
+                        f"<resource-{dep_id}>", dep_task.result
+                    )
 
     def run(self) -> str:
         for task in self.tasks:
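Note on the fix above: Python strings are immutable, so str.replace() returns a new string and leaves the original untouched; the old line discarded that return value, so the <resource-N> placeholder was never substituted. A minimal standalone sketch of the difference (illustrative only, not taken from the repository):

    args = {"image": "<resource-0>"}
    result = "generated.png"

    args["image"].replace("<resource-0>", result)  # no effect: return value is dropped
    print(args["image"])  # <resource-0>

    args["image"] = args["image"].replace("<resource-0>", result)
    print(args["image"])  # generated.png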
@@ -5,9 +5,11 @@ from langchain_experimental.comprehend_moderation.base_moderation import BaseModeration
 from langchain_experimental.comprehend_moderation.base_moderation_callbacks import (
     BaseModerationCallbackHandler,
 )
-from langchain_experimental.comprehend_moderation.base_moderation_enums import (
-    BaseModerationActions,
-    BaseModerationFilters,
+from langchain_experimental.comprehend_moderation.base_moderation_config import (
+    BaseModerationConfig,
+    ModerationIntentConfig,
+    ModerationPiiConfig,
+    ModerationToxicityConfig,
 )
 from langchain_experimental.comprehend_moderation.intent import ComprehendIntent
 from langchain_experimental.comprehend_moderation.pii import ComprehendPII
@@ -15,11 +17,13 @@ from langchain_experimental.comprehend_moderation.toxicity import ComprehendToxicity
 
 __all__ = [
     "BaseModeration",
-    "BaseModerationActions",
-    "BaseModerationFilters",
     "ComprehendPII",
     "ComprehendIntent",
     "ComprehendToxicity",
+    "BaseModerationConfig",
+    "ModerationPiiConfig",
+    "ModerationToxicityConfig",
+    "ModerationIntentConfig",
     "BaseModerationCallbackHandler",
     "AmazonComprehendModerationChain",
 ]
@@ -3,12 +3,13 @@ from typing import Any, Dict, List, Optional
 from langchain.callbacks.manager import CallbackManagerForChainRun
 from langchain.chains.base import Chain
 
-from langchain_experimental.comprehend_moderation.base_moderation import (
-    BaseModeration,
-)
+from langchain_experimental.comprehend_moderation.base_moderation import BaseModeration
 from langchain_experimental.comprehend_moderation.base_moderation_callbacks import (
     BaseModerationCallbackHandler,
 )
+from langchain_experimental.comprehend_moderation.base_moderation_config import (
+    BaseModerationConfig,
+)
 from langchain_experimental.pydantic_v1 import root_validator
 
 
@@ -21,10 +22,13 @@ class AmazonComprehendModerationChain(Chain):
     input_key: str = "input"  #: :meta private:
     """Key used to fetch/store the input in data containers. Defaults to `input`"""
 
-    moderation_config: Optional[Dict[str, Any]] = None
-    """Configuration settings for moderation"""
+    moderation_config: BaseModerationConfig = BaseModerationConfig()
+    """
+    Configuration settings for moderation,
+    defaults to BaseModerationConfig with default values
+    """
 
-    client: Optional[Any]
+    client: Optional[Any] = None
     """boto3 client object for connection to Amazon Comprehend"""
 
     region_name: Optional[str] = None
@@ -1,5 +1,5 @@
 import uuid
-from typing import Any, Callable, Dict, Optional
+from typing import Any, Callable, Optional
 
 from langchain.callbacks.manager import CallbackManagerForChainRun
 from langchain.prompts.base import StringPromptValue
@@ -15,7 +15,7 @@ class BaseModeration:
     def __init__(
         self,
         client: Any,
-        config: Optional[Dict[str, Any]] = None,
+        config: Optional[Any] = None,
         moderation_callback: Optional[Any] = None,
         unique_id: Optional[str] = None,
         run_manager: Optional[CallbackManagerForChainRun] = None,
@@ -105,6 +105,11 @@ class BaseModeration:
         self.run_manager.on_text(message)
 
     def moderate(self, prompt: Any) -> str:
+        from langchain_experimental.comprehend_moderation.base_moderation_config import (  # noqa: E501
+            ModerationIntentConfig,
+            ModerationPiiConfig,
+            ModerationToxicityConfig,
+        )
         from langchain_experimental.comprehend_moderation.base_moderation_exceptions import (  # noqa: E501
             ModerationIntentionError,
             ModerationPiiError,
@@ -115,47 +120,43 @@ class BaseModeration:
         # convert prompt to text
         input_text = self._convert_prompt_to_text(prompt=prompt)
         output_text = str()
 
         # perform moderation
-        if self.config is None:
-            # In absence of config Action will default to STOP only
-            self._log_message_for_verbose("Running pii validation...\n")
-            pii_validate = self._moderation_class(moderation_class=ComprehendPII)
-            output_text = pii_validate(prompt_value=input_text)
-
-            self._log_message_for_verbose("Running toxicity validation...\n")
-            toxicity_validate = self._moderation_class(
-                moderation_class=ComprehendToxicity
-            )
-            output_text = toxicity_validate(prompt_value=output_text)
-
-            self._log_message_for_verbose("Running intent validation...\n")
-            intent_validate = self._moderation_class(
-                moderation_class=ComprehendIntent
-            )
-            output_text = intent_validate(prompt_value=output_text)
-        else:
-            filter_functions = {
-                "pii": ComprehendPII,
-                "toxicity": ComprehendToxicity,
-                "intent": ComprehendIntent,
-            }
-            filters = self.config["filters"]
-            for _filter in filters:
-                filter_name = f"{_filter}"
-                if filter_name in filter_functions:
-                    self._log_message_for_verbose(
-                        f"Running {filter_name} Validation...\n"
-                    )
-                    validation_fn = self._moderation_class(
-                        moderation_class=filter_functions[filter_name]
-                    )
-                    input_text = input_text if not output_text else output_text
-                    output_text = validation_fn(
-                        prompt_value=input_text,
-                        config=self.config[filter_name]
-                        if filter_name in self.config
-                        else None,
-                    )
+        filter_functions = {
+            "pii": ComprehendPII,
+            "toxicity": ComprehendToxicity,
+            "intent": ComprehendIntent,
+        }
+
+        filters = self.config.filters  # type: ignore
+
+        for _filter in filters:
+            filter_name = (
+                "pii"
+                if isinstance(_filter, ModerationPiiConfig)
+                else (
+                    "toxicity"
+                    if isinstance(_filter, ModerationToxicityConfig)
+                    else (
+                        "intent"
+                        if isinstance(_filter, ModerationIntentConfig)
+                        else None
+                    )
+                )
+            )
+            if filter_name in filter_functions:
+                self._log_message_for_verbose(
+                    f"Running {filter_name} Validation...\n"
+                )
+                validation_fn = self._moderation_class(
+                    moderation_class=filter_functions[filter_name]
+                )
+                input_text = input_text if not output_text else output_text
+                output_text = validation_fn(
+                    prompt_value=input_text,
+                    config=_filter.dict(),
+                )
+
         # convert text to prompt and return
         return self._convert_text_to_prompt(prompt=prompt, text=output_text)
 
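The moderate() rewrite above dispatches on the type of each filter config object instead of string keys in a dict, and forwards each config to the validator as a plain dict. A small sketch of what one typed config contributes (illustrative; the printed dict is what pydantic v1 .dict() would produce for these values):

    from langchain_experimental.comprehend_moderation.base_moderation_config import (
        ModerationPiiConfig,
    )

    pii_filter = ModerationPiiConfig(threshold=0.8, redact=True)
    # isinstance(pii_filter, ModerationPiiConfig) resolves filter_name to "pii"
    print(pii_filter.dict())
    # {'threshold': 0.8, 'labels': [], 'redact': True, 'mask_character': '*'}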
@@ -28,19 +28,19 @@ class BaseModerationCallbackHandler:
         self, moderation_beacon: Dict[str, Any], unique_id: str, **kwargs: Any
     ) -> None:
         """Run after PII validation is complete."""
-        raise NotImplementedError("Subclasses should implement this async method.")
+        pass
 
     async def on_after_toxicity(
         self, moderation_beacon: Dict[str, Any], unique_id: str, **kwargs: Any
     ) -> None:
         """Run after Toxicity validation is complete."""
-        raise NotImplementedError("Subclasses should implement this async method.")
+        pass
 
     async def on_after_intent(
         self, moderation_beacon: Dict[str, Any], unique_id: str, **kwargs: Any
     ) -> None:
         """Run after Toxicity validation is complete."""
-        raise NotImplementedError("Subclasses should implement this async method.")
+        pass
 
     @property
     def pii_callback(self) -> bool:
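With the hooks above now defaulting to pass instead of raising, a subclass only needs to override the callbacks it actually uses. A minimal sketch (the handler class name and print-based handling are hypothetical):

    from typing import Any, Dict

    from langchain_experimental.comprehend_moderation.base_moderation_callbacks import (
        BaseModerationCallbackHandler,
    )

    class PiiOnlyCallback(BaseModerationCallbackHandler):
        async def on_after_pii(
            self, moderation_beacon: Dict[str, Any], unique_id: str, **kwargs: Any
        ) -> None:
            # Hypothetical handling: log the beacon for the run identified by unique_id.
            print(unique_id, moderation_beacon.get("moderation_status"))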
@@ -0,0 +1,51 @@
+from typing import List, Union
+
+from pydantic import BaseModel
+
+
+class ModerationPiiConfig(BaseModel):
+    threshold: float = 0.5
+    """Threshold for PII confidence score, defaults to 0.5 i.e. 50%"""
+
+    labels: List[str] = []
+    """
+    List of PII Universal Labels.
+    Defaults to `list[]`
+    """
+
+    redact: bool = False
+    """Whether to perform redaction of detected PII entities"""
+
+    mask_character: str = "*"
+    """Redaction mask character in case redact=True, defaults to asterisk (*)"""
+
+
+class ModerationToxicityConfig(BaseModel):
+    threshold: float = 0.5
+    """Threshold for Toxic label confidence score, defaults to 0.5 i.e. 50%"""
+
+    labels: List[str] = []
+    """List of toxic labels, defaults to `list[]`"""
+
+
+class ModerationIntentConfig(BaseModel):
+    threshold: float = 0.5
+    """
+    Threshold for Intent classification
+    confidence score, defaults to 0.5 i.e. 50%
+    """
+
+
+class BaseModerationConfig(BaseModel):
+    filters: List[
+        Union[ModerationPiiConfig, ModerationToxicityConfig, ModerationIntentConfig]
+    ] = [
+        ModerationPiiConfig(),
+        ModerationToxicityConfig(),
+        ModerationIntentConfig(),
+    ]
+    """
+    Filters applied to the moderation chain, defaults to
+    `[ModerationPiiConfig(), ModerationToxicityConfig(),
    ModerationIntentConfig()]`
+    """
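The new config models above replace the earlier dict-plus-enum configuration. A minimal sketch of wiring them into the chain, assuming boto3 and AWS credentials are configured in the environment (the variable names and region are illustrative):

    import boto3

    from langchain_experimental.comprehend_moderation import (
        AmazonComprehendModerationChain,
        BaseModerationConfig,
        ModerationPiiConfig,
        ModerationToxicityConfig,
    )

    # Assumed: valid AWS credentials for Amazon Comprehend in this region.
    comprehend_client = boto3.client("comprehend", region_name="us-east-1")

    moderation_config = BaseModerationConfig(
        filters=[
            ModerationPiiConfig(threshold=0.8, redact=True, mask_character="X"),
            ModerationToxicityConfig(threshold=0.6),
        ]
    )

    comprehend_moderation = AmazonComprehendModerationChain(
        moderation_config=moderation_config,
        client=comprehend_client,
    )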
@@ -1,12 +0,0 @@
-from enum import Enum
-
-
-class BaseModerationActions(Enum):
-    STOP = 1
-    ALLOW = 2
-
-
-class BaseModerationFilters(str, Enum):
-    PII = "pii"
-    TOXICITY = "toxicity"
-    INTENT = "intent"
@@ -1,6 +1,5 @@
 import asyncio
-import warnings
-from typing import Any, Dict, Optional
+from typing import Any, Optional
 
 from langchain_experimental.comprehend_moderation.base_moderation_exceptions import (
     ModerationIntentionError,
@@ -30,20 +29,17 @@ class ComprehendIntent:
         intent_endpoint = "document-classifier-endpoint/prompt-intent"
         return f"arn:aws:{service}:{region_name}:aws:{intent_endpoint}"
 
-    def validate(
-        self, prompt_value: str, config: Optional[Dict[str, Any]] = None
-    ) -> str:
+    def validate(self, prompt_value: str, config: Any = None) -> str:
         """
         Check and validate the intent of the given prompt text.
 
         Args:
-            comprehend_client: Comprehend client for intent classification
-            prompt_value (str): The input text to be checked for unintended intent
-            config (Dict[str, Any]): Configuration settings for intent checks
+            prompt_value (str): The input text to be checked for unintended intent.
+            config (Dict[str, Any]): Configuration settings for intent checks.
 
         Raises:
             ValueError: If unintended intent is found in the prompt text based
                 on the specified threshold.
 
         Returns:
             str: The input prompt_value.
@@ -53,26 +49,16 @@ class ComprehendIntent:
         Comprehend's classify_document API and raises an error if unintended
         intent is detected with a score above the specified threshold.
 
+        Example:
+            comprehend_client = boto3.client('comprehend')
+            prompt_text = "Please tell me your credit card information."
+            config = {"threshold": 0.7}
+            checked_prompt = check_intent(comprehend_client, prompt_text, config)
         """
-        from langchain_experimental.comprehend_moderation.base_moderation_enums import (
-            BaseModerationActions,
-        )
-
-        threshold = config.get("threshold", 0.5) if config else 0.5
-        action = (
-            config.get("action", BaseModerationActions.STOP)
-            if config
-            else BaseModerationActions.STOP
-        )
+        threshold = config.get("threshold")
         intent_found = False
 
-        if action == BaseModerationActions.ALLOW:
-            warnings.warn(
-                "You have allowed content with Harmful content."
-                "Defaulting to STOP action..."
-            )
-            action = BaseModerationActions.STOP
-
         endpoint_arn = self._get_arn()
         response = self.client.classify_document(
             Text=prompt_value, EndpointArn=endpoint_arn
@@ -23,33 +23,19 @@ class ComprehendPII:
         self.callback = callback
         self.unique_id = unique_id
 
-    def validate(
-        self, prompt_value: str, config: Optional[Dict[str, Any]] = None
-    ) -> str:
-        from langchain_experimental.comprehend_moderation.base_moderation_enums import (
-            BaseModerationActions,
-        )
-
-        if config:
-            action = config.get("action", BaseModerationActions.STOP)
-            if action not in [BaseModerationActions.STOP, BaseModerationActions.ALLOW]:
-                raise ValueError("Action can either be stop or allow")
-
-            return (
-                self._contains_pii(prompt_value=prompt_value, config=config)
-                if action == BaseModerationActions.STOP
-                else self._detect_pii(prompt_value=prompt_value, config=config)
-            )
-        else:
-            return self._contains_pii(prompt_value=prompt_value)
-
-    def _contains_pii(
-        self, prompt_value: str, config: Optional[Dict[str, Any]] = None
-    ) -> str:
+    def validate(self, prompt_value: str, config: Any = None) -> str:
+        redact = config.get("redact")
+        return (
+            self._detect_pii(prompt_value=prompt_value, config=config)
+            if redact
+            else self._contains_pii(prompt_value=prompt_value, config=config)
+        )
+
+    def _contains_pii(self, prompt_value: str, config: Any = None) -> str:
         """
         Checks for Personally Identifiable Information (PII) labels above a
-        specified threshold.
+        specified threshold. Uses Amazon Comprehend Contains PII Entities API. See -
+        https://docs.aws.amazon.com/comprehend/latest/APIReference/API_ContainsPiiEntities.html
         Args:
             prompt_value (str): The input text to be checked for PII labels.
             config (Dict[str, Any]): Configuration for PII check and actions.
|
|||||||
self.moderation_beacon["moderation_input"] = prompt_value
|
self.moderation_beacon["moderation_input"] = prompt_value
|
||||||
self.moderation_beacon["moderation_output"] = pii_identified
|
self.moderation_beacon["moderation_output"] = pii_identified
|
||||||
|
|
||||||
threshold = config.get("threshold", 0.5) if config else 0.5
|
threshold = config.get("threshold")
|
||||||
pii_labels = config.get("labels", []) if config else []
|
pii_labels = config.get("labels")
|
||||||
pii_found = False
|
pii_found = False
|
||||||
for entity in pii_identified["Labels"]:
|
for entity in pii_identified["Labels"]:
|
||||||
if (entity["Score"] >= threshold and entity["Name"] in pii_labels) or (
|
if (entity["Score"] >= threshold and entity["Name"] in pii_labels) or (
|
||||||
@ -93,7 +79,8 @@ class ComprehendPII:
|
|||||||
Detects and handles Personally Identifiable Information (PII) entities in the
|
Detects and handles Personally Identifiable Information (PII) entities in the
|
||||||
given prompt text using Amazon Comprehend's detect_pii_entities API. The
|
given prompt text using Amazon Comprehend's detect_pii_entities API. The
|
||||||
function provides options to redact or stop processing based on the identified
|
function provides options to redact or stop processing based on the identified
|
||||||
PII entities and a provided configuration.
|
PII entities and a provided configuration. Uses Amazon Comprehend Detect PII
|
||||||
|
Entities API.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
prompt_value (str): The input text to be checked for PII entities.
|
prompt_value (str): The input text to be checked for PII entities.
|
||||||
@ -143,9 +130,9 @@ class ComprehendPII:
|
|||||||
if pii_found:
|
if pii_found:
|
||||||
raise ModerationPiiError
|
raise ModerationPiiError
|
||||||
else:
|
else:
|
||||||
threshold = config.get("threshold", 0.5) # type: ignore
|
threshold = config.get("threshold") # type: ignore
|
||||||
pii_labels = config.get("labels", []) # type: ignore
|
pii_labels = config.get("labels") # type: ignore
|
||||||
mask_marker = config.get("mask_character", "*") # type: ignore
|
mask_marker = config.get("mask_character") # type: ignore
|
||||||
pii_found = False
|
pii_found = False
|
||||||
|
|
||||||
for entity in pii_identified["Entities"]:
|
for entity in pii_identified["Entities"]:
|
||||||
@ -157,10 +144,14 @@ class ComprehendPII:
|
|||||||
pii_found = True
|
pii_found = True
|
||||||
char_offset_begin = entity["BeginOffset"]
|
char_offset_begin = entity["BeginOffset"]
|
||||||
char_offset_end = entity["EndOffset"]
|
char_offset_end = entity["EndOffset"]
|
||||||
|
|
||||||
|
mask_length = char_offset_end - char_offset_begin + 1
|
||||||
|
masked_part = mask_marker * mask_length
|
||||||
|
|
||||||
prompt_value = (
|
prompt_value = (
|
||||||
prompt_value[:char_offset_begin]
|
prompt_value[:char_offset_begin]
|
||||||
+ mask_marker * (char_offset_end - char_offset_begin)
|
+ masked_part
|
||||||
+ prompt_value[char_offset_end:]
|
+ prompt_value[char_offset_end + 1 :]
|
||||||
)
|
)
|
||||||
|
|
||||||
if self.callback and self.callback.pii_callback:
|
if self.callback and self.callback.pii_callback:
|
||||||
|
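The hunk above changes the masking arithmetic to treat the entity's end offset as inclusive: the mask now covers end - begin + 1 characters and the remaining text resumes at end + 1. A standalone sketch of that arithmetic (the helper name and sample offsets are illustrative, not part of the library):

    def mask_span(text: str, begin: int, end: int, mask_char: str = "*") -> str:
        # Mask characters from begin through end, inclusive, mirroring the new
        # _detect_pii arithmetic.
        masked_part = mask_char * (end - begin + 1)
        return text[:begin] + masked_part + text[end + 1 :]

    print(mask_span("call me at 5551234", 11, 17))  # call me at *******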
@@ -1,7 +1,6 @@
 import asyncio
 import importlib
-import warnings
-from typing import Any, Dict, List, Optional
+from typing import Any, List, Optional
 
 from langchain_experimental.comprehend_moderation.base_moderation_exceptions import (
     ModerationToxicityError,
|
|||||||
Validate and initialize toxicity processing configuration.
|
Validate and initialize toxicity processing configuration.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
max_size (int): Maximum sentence size defined in the configuration object.
|
max_size (int): Maximum sentence size defined in the
|
||||||
|
configuration object.
|
||||||
|
|
||||||
Raises:
|
Raises:
|
||||||
Exception: If the maximum sentence size exceeds the 5KB limit.
|
Exception: If the maximum sentence size exceeds the 5KB limit.
|
||||||
|
|
||||||
Note:
|
Note:
|
||||||
This function ensures that the NLTK punkt tokenizer is downloaded if not
|
This function ensures that the NLTK punkt tokenizer is downloaded
|
||||||
already present.
|
if not already present.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
None
|
None
|
||||||
@ -63,34 +63,36 @@ class ComprehendToxicity:
|
|||||||
Split a paragraph into chunks of sentences, respecting the maximum size limit.
|
Split a paragraph into chunks of sentences, respecting the maximum size limit.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
paragraph (str): The input paragraph to be split into chunks
|
paragraph (str): The input paragraph to be split into chunks.
|
||||||
max_size (int, optional): The maximum size limit in bytes for each chunk
|
max_size (int, optional): The maximum size limit in bytes for
|
||||||
Defaults to 1024.
|
each chunk. Defaults to 1024.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
List[List[str]]: A list of chunks, where each chunk is a list of sentences
|
List[List[str]]: A list of chunks, where each chunk is a list
|
||||||
|
of sentences.
|
||||||
|
|
||||||
Note:
|
Note:
|
||||||
This function validates the maximum sentence size based on service limits
|
This function validates the maximum sentence size based on service
|
||||||
using the 'toxicity_init_validate' function. It uses the NLTK sentence
|
limits using the 'toxicity_init_validate' function. It uses the NLTK
|
||||||
tokenizer to split the paragraph into sentences.
|
sentence tokenizer to split the paragraph into sentences.
|
||||||
|
|
||||||
|
Example:
|
||||||
|
paragraph = "This is a sample paragraph. It
|
||||||
|
contains multiple sentences. ..."
|
||||||
|
chunks = split_paragraph(paragraph, max_size=2048)
|
||||||
"""
|
"""
|
||||||
|
|
||||||
# validate max. sentence size based on Service limits
|
# validate max. sentence size based on Service limits
|
||||||
nltk = self._toxicity_init_validate(max_size)
|
nltk = self._toxicity_init_validate(max_size)
|
||||||
|
|
||||||
sentences = nltk.sent_tokenize(prompt_value)
|
sentences = nltk.sent_tokenize(prompt_value)
|
||||||
|
chunks = list() # type: ignore
|
||||||
chunks = []
|
current_chunk = list() # type: ignore
|
||||||
current_chunk = [] # type: ignore
|
|
||||||
current_size = 0
|
current_size = 0
|
||||||
|
|
||||||
for sentence in sentences:
|
for sentence in sentences:
|
||||||
sentence_size = len(sentence.encode("utf-8"))
|
sentence_size = len(sentence.encode("utf-8"))
|
||||||
|
# If adding a new sentence exceeds max_size
|
||||||
# If adding a new sentence exceeds max_size or
|
# or current_chunk has 10 sentences, start a new chunk
|
||||||
# current_chunk has 10 sentences, start a new chunk
|
|
||||||
if (current_size + sentence_size > max_size) or (len(current_chunk) >= 10):
|
if (current_size + sentence_size > max_size) or (len(current_chunk) >= 10):
|
||||||
if current_chunk: # Avoid appending empty chunks
|
if current_chunk: # Avoid appending empty chunks
|
||||||
chunks.append(current_chunk)
|
chunks.append(current_chunk)
|
||||||
@ -103,16 +105,12 @@ class ComprehendToxicity:
|
|||||||
# Add any remaining sentences
|
# Add any remaining sentences
|
||||||
if current_chunk:
|
if current_chunk:
|
||||||
chunks.append(current_chunk)
|
chunks.append(current_chunk)
|
||||||
|
|
||||||
return chunks
|
return chunks
|
||||||
|
|
||||||
def validate(
|
def validate(self, prompt_value: str, config: Any = None) -> str:
|
||||||
self, prompt_value: str, config: Optional[Dict[str, Any]] = None
|
|
||||||
) -> str:
|
|
||||||
"""
|
"""
|
||||||
Check the toxicity of a given text prompt using AWS Comprehend service
|
Check the toxicity of a given text prompt using AWS
|
||||||
and apply actions based on configuration.
|
Comprehend service and apply actions based on configuration.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
prompt_value (str): The text content to be checked for toxicity.
|
prompt_value (str): The text content to be checked for toxicity.
|
||||||
config (Dict[str, Any]): Configuration for toxicity checks and actions.
|
config (Dict[str, Any]): Configuration for toxicity checks and actions.
|
||||||
@@ -122,7 +120,7 @@ class ComprehendToxicity:
 
         Raises:
             ValueError: If the prompt contains toxic labels and cannot be
                 processed based on the configuration.
         """
 
         chunks = self._split_paragraph(prompt_value=prompt_value)
|
|||||||
if self.callback and self.callback.toxicity_callback:
|
if self.callback and self.callback.toxicity_callback:
|
||||||
self.moderation_beacon["moderation_input"] = segments # type: ignore
|
self.moderation_beacon["moderation_input"] = segments # type: ignore
|
||||||
self.moderation_beacon["moderation_output"] = response
|
self.moderation_beacon["moderation_output"] = response
|
||||||
|
toxicity_found = False
|
||||||
|
threshold = config.get("threshold")
|
||||||
|
toxicity_labels = config.get("labels")
|
||||||
|
|
||||||
if config:
|
if not toxicity_labels:
|
||||||
from langchain_experimental.comprehend_moderation.base_moderation_enums import ( # noqa: E501
|
for item in response["ResultList"]:
|
||||||
BaseModerationActions,
|
for label in item["Labels"]:
|
||||||
)
|
if label["Score"] >= threshold:
|
||||||
|
toxicity_found = True
|
||||||
toxicity_found = False
|
break
|
||||||
action = config.get("action", BaseModerationActions.STOP)
|
|
||||||
if action not in [
|
|
||||||
BaseModerationActions.STOP,
|
|
||||||
BaseModerationActions.ALLOW,
|
|
||||||
]:
|
|
||||||
raise ValueError("Action can either be stop or allow")
|
|
||||||
|
|
||||||
threshold = config.get("threshold", 0.5) if config else 0.5
|
|
||||||
toxicity_labels = config.get("labels", []) if config else []
|
|
||||||
|
|
||||||
if action == BaseModerationActions.STOP:
|
|
||||||
for item in response["ResultList"]:
|
|
||||||
for label in item["Labels"]:
|
|
||||||
if (
|
|
||||||
label
|
|
||||||
and (
|
|
||||||
not toxicity_labels
|
|
||||||
or label["Name"] in toxicity_labels
|
|
||||||
)
|
|
||||||
and label["Score"] >= threshold
|
|
||||||
):
|
|
||||||
toxicity_found = True
|
|
||||||
break
|
|
||||||
|
|
||||||
if action == BaseModerationActions.ALLOW:
|
|
||||||
if not toxicity_labels:
|
|
||||||
warnings.warn(
|
|
||||||
"You have allowed toxic content without specifying "
|
|
||||||
"any toxicity labels."
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
for item in response["ResultList"]:
|
|
||||||
for label in item["Labels"]:
|
|
||||||
if (
|
|
||||||
label["Name"] in toxicity_labels
|
|
||||||
and label["Score"] >= threshold
|
|
||||||
):
|
|
||||||
toxicity_found = True
|
|
||||||
break
|
|
||||||
|
|
||||||
if self.callback and self.callback.toxicity_callback:
|
|
||||||
if toxicity_found:
|
|
||||||
self.moderation_beacon["moderation_status"] = "LABELS_FOUND"
|
|
||||||
asyncio.create_task(
|
|
||||||
self.callback.on_after_toxicity(
|
|
||||||
self.moderation_beacon, self.unique_id
|
|
||||||
)
|
|
||||||
)
|
|
||||||
if toxicity_found:
|
|
||||||
raise ModerationToxicityError
|
|
||||||
else:
|
else:
|
||||||
if response["ResultList"]:
|
for item in response["ResultList"]:
|
||||||
detected_toxic_labels = list()
|
for label in item["Labels"]:
|
||||||
for item in response["ResultList"]:
|
if (
|
||||||
detected_toxic_labels.extend(item["Labels"])
|
label["Name"] in toxicity_labels
|
||||||
if any(item["Score"] >= 0.5 for item in detected_toxic_labels):
|
and label["Score"] >= threshold
|
||||||
if self.callback and self.callback.toxicity_callback:
|
):
|
||||||
self.moderation_beacon["moderation_status"] = "LABELS_FOUND"
|
toxicity_found = True
|
||||||
asyncio.create_task(
|
break
|
||||||
self.callback.on_after_toxicity(
|
|
||||||
self.moderation_beacon, self.unique_id
|
|
||||||
)
|
|
||||||
)
|
|
||||||
raise ModerationToxicityError
|
|
||||||
|
|
||||||
|
if self.callback and self.callback.toxicity_callback:
|
||||||
|
if toxicity_found:
|
||||||
|
self.moderation_beacon["moderation_status"] = "LABELS_FOUND"
|
||||||
|
asyncio.create_task(
|
||||||
|
self.callback.on_after_toxicity(
|
||||||
|
self.moderation_beacon, self.unique_id
|
||||||
|
)
|
||||||
|
)
|
||||||
|
if toxicity_found:
|
||||||
|
raise ModerationToxicityError
|
||||||
return prompt_value
|
return prompt_value
|
||||||
|
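The replacement logic above reduces the toxicity decision to a threshold check, optionally restricted to configured labels. A pure-function sketch of that decision (the function name and sample data are illustrative, not part of the library):

    from typing import Any, Dict, List

    def is_toxic(
        result_list: List[Dict[str, Any]], threshold: float, labels: List[str]
    ) -> bool:
        # With no labels configured, any label at or above the threshold triggers;
        # otherwise only the configured labels are considered.
        for item in result_list:
            for label in item["Labels"]:
                if not labels:
                    if label["Score"] >= threshold:
                        return True
                elif label["Name"] in labels and label["Score"] >= threshold:
                    return True
        return False

    sample = [{"Labels": [{"Name": "PROFANITY", "Score": 0.71}]}]
    print(is_toxic(sample, threshold=0.6, labels=[]))               # True
    print(is_toxic(sample, threshold=0.6, labels=["HATE_SPEECH"]))  # False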