Mirror of https://github.com/hwchase17/langchain.git (synced 2025-06-29 18:08:36 +00:00)
Harrison/string inplace (#10153)
Co-authored-by: Wrick Talukdar <wrick.talukdar@gmail.com>
Co-authored-by: Anjan Biswas <anjanavb@amazon.com>
Co-authored-by: Jha <nikjha@amazon.com>
Co-authored-by: Lucky-Lance <77819606+Lucky-Lance@users.noreply.github.com>
Co-authored-by: 陆徐东 <luxudong@MacBook-Pro.local>
This commit is contained in:
parent f5af756397
commit 4abe85be57
docs/docs_skeleton/docs/guides/safety/amazon_comprehend_chain.ipynb (new file, 1422 lines): file diff suppressed because it is too large.
@@ -111,7 +111,9 @@ class TaskExecutor:
             dep_task = self.id_task_map[dep_id]
             for k, v in task.args.items():
                 if f"<resource-{dep_id}>" in v:
-                    task.args[k].replace(f"<resource-{dep_id}>", dep_task.result)
+                    task.args[k] = task.args[k].replace(
+                        f"<resource-{dep_id}>", dep_task.result
+                    )

    def run(self) -> str:
        for task in self.tasks:
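The substantive fix in this hunk is that Python strings are immutable: str.replace returns a new string rather than modifying the receiver, so the result has to be assigned back into task.args. A minimal illustration (the placeholder and result values below are hypothetical, not taken from the diff):

task_args = {"text": "Summarize <resource-0> for me"}
dep_result = "the quarterly report"

# Calling replace() without assignment discards the new string.
task_args["text"].replace("<resource-0>", dep_result)
print(task_args["text"])  # Summarize <resource-0> for me

# Assigning the result back, as the new code does, applies the substitution.
task_args["text"] = task_args["text"].replace("<resource-0>", dep_result)
print(task_args["text"])  # Summarize the quarterly report for me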
@@ -5,9 +5,11 @@ from langchain_experimental.comprehend_moderation.base_moderation import BaseModeration
 from langchain_experimental.comprehend_moderation.base_moderation_callbacks import (
     BaseModerationCallbackHandler,
 )
-from langchain_experimental.comprehend_moderation.base_moderation_enums import (
-    BaseModerationActions,
-    BaseModerationFilters,
+from langchain_experimental.comprehend_moderation.base_moderation_config import (
+    BaseModerationConfig,
+    ModerationIntentConfig,
+    ModerationPiiConfig,
+    ModerationToxicityConfig,
 )
 from langchain_experimental.comprehend_moderation.intent import ComprehendIntent
 from langchain_experimental.comprehend_moderation.pii import ComprehendPII
@@ -15,11 +17,13 @@ from langchain_experimental.comprehend_moderation.toxicity import ComprehendToxicity

 __all__ = [
     "BaseModeration",
-    "BaseModerationActions",
-    "BaseModerationFilters",
     "ComprehendPII",
     "ComprehendIntent",
     "ComprehendToxicity",
+    "BaseModerationConfig",
+    "ModerationPiiConfig",
+    "ModerationToxicityConfig",
+    "ModerationIntentConfig",
     "BaseModerationCallbackHandler",
     "AmazonComprehendModerationChain",
 ]
@@ -3,12 +3,13 @@ from typing import Any, Dict, List, Optional
 from langchain.callbacks.manager import CallbackManagerForChainRun
 from langchain.chains.base import Chain

-from langchain_experimental.comprehend_moderation.base_moderation import (
-    BaseModeration,
-)
+from langchain_experimental.comprehend_moderation.base_moderation import BaseModeration
 from langchain_experimental.comprehend_moderation.base_moderation_callbacks import (
     BaseModerationCallbackHandler,
 )
+from langchain_experimental.comprehend_moderation.base_moderation_config import (
+    BaseModerationConfig,
+)
 from langchain_experimental.pydantic_v1 import root_validator

@@ -21,10 +22,13 @@ class AmazonComprehendModerationChain(Chain):
     input_key: str = "input"  #: :meta private:
     """Key used to fetch/store the input in data containers. Defaults to `input`"""

-    moderation_config: Optional[Dict[str, Any]] = None
-    """Configuration settings for moderation"""
+    moderation_config: BaseModerationConfig = BaseModerationConfig()
+    """
+    Configuration settings for moderation,
+    defaults to BaseModerationConfig with default values
+    """

-    client: Optional[Any]
+    client: Optional[Any] = None
     """boto3 client object for connection to Amazon Comprehend"""

     region_name: Optional[str] = None
@@ -1,5 +1,5 @@
 import uuid
-from typing import Any, Callable, Dict, Optional
+from typing import Any, Callable, Optional

 from langchain.callbacks.manager import CallbackManagerForChainRun
 from langchain.prompts.base import StringPromptValue
@@ -15,7 +15,7 @@ class BaseModeration:
     def __init__(
         self,
         client: Any,
-        config: Optional[Dict[str, Any]] = None,
+        config: Optional[Any] = None,
         moderation_callback: Optional[Any] = None,
         unique_id: Optional[str] = None,
         run_manager: Optional[CallbackManagerForChainRun] = None,
@@ -105,6 +105,11 @@ class BaseModeration:
             self.run_manager.on_text(message)

     def moderate(self, prompt: Any) -> str:
+        from langchain_experimental.comprehend_moderation.base_moderation_config import (  # noqa: E501
+            ModerationIntentConfig,
+            ModerationPiiConfig,
+            ModerationToxicityConfig,
+        )
         from langchain_experimental.comprehend_moderation.base_moderation_exceptions import (  # noqa: E501
             ModerationIntentionError,
             ModerationPiiError,
@@ -115,47 +120,43 @@ class BaseModeration:
         # convert prompt to text
         input_text = self._convert_prompt_to_text(prompt=prompt)
         output_text = str()

         # perform moderation
-        if self.config is None:
-            # In absence of config Action will default to STOP only
-            self._log_message_for_verbose("Running pii validation...\n")
-            pii_validate = self._moderation_class(moderation_class=ComprehendPII)
-            output_text = pii_validate(prompt_value=input_text)
-            self._log_message_for_verbose("Running toxicity validation...\n")
-            toxicity_validate = self._moderation_class(
-                moderation_class=ComprehendToxicity
-            )
-            output_text = toxicity_validate(prompt_value=output_text)
-            self._log_message_for_verbose("Running intent validation...\n")
-            intent_validate = self._moderation_class(
-                moderation_class=ComprehendIntent
-            )
-            output_text = intent_validate(prompt_value=output_text)
-        else:
-            filter_functions = {
-                "pii": ComprehendPII,
-                "toxicity": ComprehendToxicity,
-                "intent": ComprehendIntent,
-            }
-            filters = self.config["filters"]
-            for _filter in filters:
-                filter_name = f"{_filter}"
-                if filter_name in filter_functions:
-                    self._log_message_for_verbose(
-                        f"Running {filter_name} Validation...\n"
-                    )
-                    validation_fn = self._moderation_class(
-                        moderation_class=filter_functions[filter_name]
-                    )
-                    input_text = input_text if not output_text else output_text
-                    output_text = validation_fn(
-                        prompt_value=input_text,
-                        config=self.config[filter_name]
-                        if filter_name in self.config
-                        else None,
-                    )
+        filter_functions = {
+            "pii": ComprehendPII,
+            "toxicity": ComprehendToxicity,
+            "intent": ComprehendIntent,
+        }
+
+        filters = self.config.filters  # type: ignore
+
+        for _filter in filters:
+            filter_name = (
+                "pii"
+                if isinstance(_filter, ModerationPiiConfig)
+                else (
+                    "toxicity"
+                    if isinstance(_filter, ModerationToxicityConfig)
+                    else (
+                        "intent"
+                        if isinstance(_filter, ModerationIntentConfig)
+                        else None
+                    )
+                )
+            )
+            if filter_name in filter_functions:
+                self._log_message_for_verbose(
+                    f"Running {filter_name} Validation...\n"
+                )
+                validation_fn = self._moderation_class(
+                    moderation_class=filter_functions[filter_name]
+                )
+                input_text = input_text if not output_text else output_text
+                output_text = validation_fn(
+                    prompt_value=input_text,
+                    config=_filter.dict(),
+                )

         # convert text to prompt and return
         return self._convert_text_to_prompt(prompt=prompt, text=output_text)
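For readers puzzling over the nested conditional expression above: it simply maps each config instance to the key used in filter_functions. An equivalent type-keyed lookup is shown below purely as an illustration; this dict is not part of the commit.

from langchain_experimental.comprehend_moderation import (
    ModerationIntentConfig,
    ModerationPiiConfig,
    ModerationToxicityConfig,
)

# Illustrative only: map each config type from this diff to its filter key.
FILTER_NAME_BY_CONFIG = {
    ModerationPiiConfig: "pii",
    ModerationToxicityConfig: "toxicity",
    ModerationIntentConfig: "intent",
}

_filter = ModerationToxicityConfig(threshold=0.6)
filter_name = FILTER_NAME_BY_CONFIG.get(type(_filter))
print(filter_name)  # toxicity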
@@ -28,19 +28,19 @@ class BaseModerationCallbackHandler:
         self, moderation_beacon: Dict[str, Any], unique_id: str, **kwargs: Any
     ) -> None:
         """Run after PII validation is complete."""
-        raise NotImplementedError("Subclasses should implement this async method.")
+        pass

     async def on_after_toxicity(
         self, moderation_beacon: Dict[str, Any], unique_id: str, **kwargs: Any
     ) -> None:
         """Run after Toxicity validation is complete."""
-        raise NotImplementedError("Subclasses should implement this async method.")
+        pass

     async def on_after_intent(
         self, moderation_beacon: Dict[str, Any], unique_id: str, **kwargs: Any
     ) -> None:
         """Run after Toxicity validation is complete."""
-        raise NotImplementedError("Subclasses should implement this async method.")
+        pass

     @property
     def pii_callback(self) -> bool:
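Because the hooks above now default to pass instead of raising, a subclass only needs to override the callbacks it cares about. A minimal sketch of a custom handler (the class name and print logic are hypothetical; the beacon keys follow the ones set elsewhere in this diff):

from typing import Any, Dict

from langchain_experimental.comprehend_moderation.base_moderation_callbacks import (
    BaseModerationCallbackHandler,
)


class LoggingModerationHandler(BaseModerationCallbackHandler):
    async def on_after_pii(
        self, moderation_beacon: Dict[str, Any], unique_id: str, **kwargs: Any
    ) -> None:
        # moderation_beacon carries moderation_input/moderation_output/moderation_status
        # populated by ComprehendPII before this hook fires.
        print(f"[{unique_id}] PII check finished: {moderation_beacon.get('moderation_status')}")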
@@ -0,0 +1,51 @@
+from typing import List, Union
+
+from pydantic import BaseModel
+
+
+class ModerationPiiConfig(BaseModel):
+    threshold: float = 0.5
+    """Threshold for PII confidence score, defaults to 0.5 i.e. 50%"""
+
+    labels: List[str] = []
+    """
+    List of PII Universal Labels.
+    Defaults to `list[]`
+    """
+
+    redact: bool = False
+    """Whether to perform redaction of detected PII entities"""
+
+    mask_character: str = "*"
+    """Redaction mask character in case redact=True, defaults to asterisk (*)"""
+
+
+class ModerationToxicityConfig(BaseModel):
+    threshold: float = 0.5
+    """Threshold for Toxic label confidence score, defaults to 0.5 i.e. 50%"""
+
+    labels: List[str] = []
+    """List of toxic labels, defaults to `list[]`"""
+
+
+class ModerationIntentConfig(BaseModel):
+    threshold: float = 0.5
+    """
+    Threshold for Intent classification
+    confidence score, defaults to 0.5 i.e. 50%
+    """
+
+
+class BaseModerationConfig(BaseModel):
+    filters: List[
+        Union[ModerationPiiConfig, ModerationToxicityConfig, ModerationIntentConfig]
+    ] = [
+        ModerationPiiConfig(),
+        ModerationToxicityConfig(),
+        ModerationIntentConfig(),
+    ]
+    """
+    Filters applied to the moderation chain, defaults to
+    `[ModerationPiiConfig(), ModerationToxicityConfig(),
+    ModerationIntentConfig()]`
+    """
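Together with the chain changes earlier in this diff, the new file above replaces the old dict-plus-enum configuration with typed pydantic objects. A minimal usage sketch (assumes boto3 credentials are configured; the threshold and redaction values are illustrative):

from langchain_experimental.comprehend_moderation import (
    AmazonComprehendModerationChain,
    BaseModerationConfig,
    ModerationPiiConfig,
    ModerationToxicityConfig,
)

# Redact PII above 80% confidence and flag toxicity above 60% confidence.
moderation_config = BaseModerationConfig(
    filters=[
        ModerationPiiConfig(threshold=0.8, redact=True, mask_character="X"),
        ModerationToxicityConfig(threshold=0.6),
    ]
)

moderation_chain = AmazonComprehendModerationChain(
    moderation_config=moderation_config,  # defaults to BaseModerationConfig() when omitted
    verbose=True,
)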
@@ -1,12 +0,0 @@
-from enum import Enum
-
-
-class BaseModerationActions(Enum):
-    STOP = 1
-    ALLOW = 2
-
-
-class BaseModerationFilters(str, Enum):
-    PII = "pii"
-    TOXICITY = "toxicity"
-    INTENT = "intent"
@@ -1,6 +1,5 @@
 import asyncio
-import warnings
-from typing import Any, Dict, Optional
+from typing import Any, Optional

 from langchain_experimental.comprehend_moderation.base_moderation_exceptions import (
     ModerationIntentionError,
@@ -30,20 +29,17 @@ class ComprehendIntent:
         intent_endpoint = "document-classifier-endpoint/prompt-intent"
         return f"arn:aws:{service}:{region_name}:aws:{intent_endpoint}"

-    def validate(
-        self, prompt_value: str, config: Optional[Dict[str, Any]] = None
-    ) -> str:
+    def validate(self, prompt_value: str, config: Any = None) -> str:
         """
         Check and validate the intent of the given prompt text.

         Args:
-            comprehend_client: Comprehend client for intent classification
-            prompt_value (str): The input text to be checked for unintended intent
-            config (Dict[str, Any]): Configuration settings for intent checks
+            prompt_value (str): The input text to be checked for unintended intent.
+            config (Dict[str, Any]): Configuration settings for intent checks.

         Raises:
             ValueError: If unintended intent is found in the prompt text based
-            on the specified threshold.
+                on the specified threshold.

         Returns:
             str: The input prompt_value.
@@ -53,26 +49,16 @@ class ComprehendIntent:
         Comprehend's classify_document API and raises an error if unintended
         intent is detected with a score above the specified threshold.

+        Example:
+            comprehend_client = boto3.client('comprehend')
+            prompt_text = "Please tell me your credit card information."
+            config = {"threshold": 0.7}
+            checked_prompt = check_intent(comprehend_client, prompt_text, config)
         """
-        from langchain_experimental.comprehend_moderation.base_moderation_enums import (
-            BaseModerationActions,
-        )
-
-        threshold = config.get("threshold", 0.5) if config else 0.5
-        action = (
-            config.get("action", BaseModerationActions.STOP)
-            if config
-            else BaseModerationActions.STOP
-        )
+        threshold = config.get("threshold")
         intent_found = False

-        if action == BaseModerationActions.ALLOW:
-            warnings.warn(
-                "You have allowed content with Harmful content."
-                "Defaulting to STOP action..."
-            )
-            action = BaseModerationActions.STOP
-
         endpoint_arn = self._get_arn()
         response = self.client.classify_document(
             Text=prompt_value, EndpointArn=endpoint_arn
@@ -23,33 +23,19 @@ class ComprehendPII:
         self.callback = callback
         self.unique_id = unique_id

-    def validate(
-        self, prompt_value: str, config: Optional[Dict[str, Any]] = None
-    ) -> str:
-        from langchain_experimental.comprehend_moderation.base_moderation_enums import (
-            BaseModerationActions,
+    def validate(self, prompt_value: str, config: Any = None) -> str:
+        redact = config.get("redact")
+        return (
+            self._detect_pii(prompt_value=prompt_value, config=config)
+            if redact
+            else self._contains_pii(prompt_value=prompt_value, config=config)
         )

-        if config:
-            action = config.get("action", BaseModerationActions.STOP)
-            if action not in [BaseModerationActions.STOP, BaseModerationActions.ALLOW]:
-                raise ValueError("Action can either be stop or allow")
-
-            return (
-                self._contains_pii(prompt_value=prompt_value, config=config)
-                if action == BaseModerationActions.STOP
-                else self._detect_pii(prompt_value=prompt_value, config=config)
-            )
-        else:
-            return self._contains_pii(prompt_value=prompt_value)
-
-    def _contains_pii(
-        self, prompt_value: str, config: Optional[Dict[str, Any]] = None
-    ) -> str:
+    def _contains_pii(self, prompt_value: str, config: Any = None) -> str:
         """
         Checks for Personally Identifiable Information (PII) labels above a
-        specified threshold.
-
+        specified threshold. Uses Amazon Comprehend Contains PII Entities API. See -
+        https://docs.aws.amazon.com/comprehend/latest/APIReference/API_ContainsPiiEntities.html
         Args:
             prompt_value (str): The input text to be checked for PII labels.
             config (Dict[str, Any]): Configuration for PII check and actions.
@@ -68,8 +54,8 @@ class ComprehendPII:
             self.moderation_beacon["moderation_input"] = prompt_value
             self.moderation_beacon["moderation_output"] = pii_identified

-        threshold = config.get("threshold", 0.5) if config else 0.5
-        pii_labels = config.get("labels", []) if config else []
+        threshold = config.get("threshold")
+        pii_labels = config.get("labels")
         pii_found = False
         for entity in pii_identified["Labels"]:
             if (entity["Score"] >= threshold and entity["Name"] in pii_labels) or (
@@ -93,7 +79,8 @@ class ComprehendPII:
         Detects and handles Personally Identifiable Information (PII) entities in the
         given prompt text using Amazon Comprehend's detect_pii_entities API. The
         function provides options to redact or stop processing based on the identified
-        PII entities and a provided configuration.
+        PII entities and a provided configuration. Uses Amazon Comprehend Detect PII
+        Entities API.

         Args:
             prompt_value (str): The input text to be checked for PII entities.
@@ -143,9 +130,9 @@ class ComprehendPII:
             if pii_found:
                 raise ModerationPiiError
         else:
-            threshold = config.get("threshold", 0.5)  # type: ignore
-            pii_labels = config.get("labels", [])  # type: ignore
-            mask_marker = config.get("mask_character", "*")  # type: ignore
+            threshold = config.get("threshold")  # type: ignore
+            pii_labels = config.get("labels")  # type: ignore
+            mask_marker = config.get("mask_character")  # type: ignore
             pii_found = False

             for entity in pii_identified["Entities"]:
@@ -157,10 +144,14 @@ class ComprehendPII:
                     pii_found = True
                     char_offset_begin = entity["BeginOffset"]
                     char_offset_end = entity["EndOffset"]
+
+                    mask_length = char_offset_end - char_offset_begin + 1
+                    masked_part = mask_marker * mask_length
+
                     prompt_value = (
                         prompt_value[:char_offset_begin]
-                        + mask_marker * (char_offset_end - char_offset_begin)
-                        + prompt_value[char_offset_end:]
+                        + masked_part
+                        + prompt_value[char_offset_end + 1 :]
                     )

         if self.callback and self.callback.pii_callback:
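The new masking code above treats EndOffset as inclusive: it masks end - begin + 1 characters and resumes the original text at end + 1. A small worked example of that arithmetic, with hypothetical offsets (not taken from any API response):

text = "My name is John."
begin, end, mask = 11, 14, "*"  # hypothetical BeginOffset/EndOffset for "John"

masked = text[:begin] + mask * (end - begin + 1) + text[end + 1 :]
print(masked)  # My name is ****.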
@@ -1,7 +1,6 @@
 import asyncio
 import importlib
-import warnings
-from typing import Any, Dict, List, Optional
+from typing import Any, List, Optional

 from langchain_experimental.comprehend_moderation.base_moderation_exceptions import (
     ModerationToxicityError,
@@ -30,14 +29,15 @@ class ComprehendToxicity:
         Validate and initialize toxicity processing configuration.

         Args:
-            max_size (int): Maximum sentence size defined in the configuration object.
+            max_size (int): Maximum sentence size defined in the
+            configuration object.

         Raises:
             Exception: If the maximum sentence size exceeds the 5KB limit.

         Note:
-            This function ensures that the NLTK punkt tokenizer is downloaded if not
-            already present.
+            This function ensures that the NLTK punkt tokenizer is downloaded
+            if not already present.

         Returns:
             None
@@ -63,34 +63,36 @@ class ComprehendToxicity:
         Split a paragraph into chunks of sentences, respecting the maximum size limit.

         Args:
-            paragraph (str): The input paragraph to be split into chunks
-            max_size (int, optional): The maximum size limit in bytes for each chunk
-            Defaults to 1024.
+            paragraph (str): The input paragraph to be split into chunks.
+            max_size (int, optional): The maximum size limit in bytes for
+            each chunk. Defaults to 1024.

         Returns:
-            List[List[str]]: A list of chunks, where each chunk is a list of sentences
+            List[List[str]]: A list of chunks, where each chunk is a list
+            of sentences.

         Note:
-            This function validates the maximum sentence size based on service limits
-            using the 'toxicity_init_validate' function. It uses the NLTK sentence
-            tokenizer to split the paragraph into sentences.
+            This function validates the maximum sentence size based on service
+            limits using the 'toxicity_init_validate' function. It uses the NLTK
+            sentence tokenizer to split the paragraph into sentences.

         Example:
             paragraph = "This is a sample paragraph. It
             contains multiple sentences. ..."
             chunks = split_paragraph(paragraph, max_size=2048)
         """

         # validate max. sentence size based on Service limits
         nltk = self._toxicity_init_validate(max_size)

         sentences = nltk.sent_tokenize(prompt_value)

-        chunks = []
-        current_chunk = []  # type: ignore
+        chunks = list()  # type: ignore
+        current_chunk = list()  # type: ignore
         current_size = 0

         for sentence in sentences:
             sentence_size = len(sentence.encode("utf-8"))

-            # If adding a new sentence exceeds max_size or
-            # current_chunk has 10 sentences, start a new chunk
+            # If adding a new sentence exceeds max_size
+            # or current_chunk has 10 sentences, start a new chunk
             if (current_size + sentence_size > max_size) or (len(current_chunk) >= 10):
                 if current_chunk:  # Avoid appending empty chunks
                     chunks.append(current_chunk)
@@ -103,16 +105,12 @@ class ComprehendToxicity:
         # Add any remaining sentences
         if current_chunk:
             chunks.append(current_chunk)

         return chunks

-    def validate(
-        self, prompt_value: str, config: Optional[Dict[str, Any]] = None
-    ) -> str:
+    def validate(self, prompt_value: str, config: Any = None) -> str:
         """
-        Check the toxicity of a given text prompt using AWS Comprehend service
-        and apply actions based on configuration.
-
+        Check the toxicity of a given text prompt using AWS
+        Comprehend service and apply actions based on configuration.
         Args:
             prompt_value (str): The text content to be checked for toxicity.
             config (Dict[str, Any]): Configuration for toxicity checks and actions.
@@ -122,7 +120,7 @@ class ComprehendToxicity:

         Raises:
             ValueError: If the prompt contains toxic labels and cannot be
-            processed based on the configuration.
+                processed based on the configuration.
         """

         chunks = self._split_paragraph(prompt_value=prompt_value)
@@ -134,76 +132,34 @@ class ComprehendToxicity:
             if self.callback and self.callback.toxicity_callback:
                 self.moderation_beacon["moderation_input"] = segments  # type: ignore
                 self.moderation_beacon["moderation_output"] = response
-
-            if config:
-                from langchain_experimental.comprehend_moderation.base_moderation_enums import (  # noqa: E501
-                    BaseModerationActions,
-                )
-
-                toxicity_found = False
-                action = config.get("action", BaseModerationActions.STOP)
-                if action not in [
-                    BaseModerationActions.STOP,
-                    BaseModerationActions.ALLOW,
-                ]:
-                    raise ValueError("Action can either be stop or allow")
-
-                threshold = config.get("threshold", 0.5) if config else 0.5
-                toxicity_labels = config.get("labels", []) if config else []
-
-                if action == BaseModerationActions.STOP:
-                    for item in response["ResultList"]:
-                        for label in item["Labels"]:
-                            if (
-                                label
-                                and (
-                                    not toxicity_labels
-                                    or label["Name"] in toxicity_labels
-                                )
-                                and label["Score"] >= threshold
-                            ):
-                                toxicity_found = True
-                                break
-
-                if action == BaseModerationActions.ALLOW:
-                    if not toxicity_labels:
-                        warnings.warn(
-                            "You have allowed toxic content without specifying "
-                            "any toxicity labels."
-                        )
-                    else:
-                        for item in response["ResultList"]:
-                            for label in item["Labels"]:
-                                if (
-                                    label["Name"] in toxicity_labels
-                                    and label["Score"] >= threshold
-                                ):
-                                    toxicity_found = True
-                                    break
-
-                if self.callback and self.callback.toxicity_callback:
-                    if toxicity_found:
-                        self.moderation_beacon["moderation_status"] = "LABELS_FOUND"
-                    asyncio.create_task(
-                        self.callback.on_after_toxicity(
-                            self.moderation_beacon, self.unique_id
-                        )
-                    )
-                if toxicity_found:
-                    raise ModerationToxicityError
-            else:
-                if response["ResultList"]:
-                    detected_toxic_labels = list()
-                    for item in response["ResultList"]:
-                        detected_toxic_labels.extend(item["Labels"])
-                    if any(item["Score"] >= 0.5 for item in detected_toxic_labels):
-                        if self.callback and self.callback.toxicity_callback:
-                            self.moderation_beacon["moderation_status"] = "LABELS_FOUND"
-                            asyncio.create_task(
-                                self.callback.on_after_toxicity(
-                                    self.moderation_beacon, self.unique_id
-                                )
-                            )
-                        raise ModerationToxicityError
-
+            toxicity_found = False
+            threshold = config.get("threshold")
+            toxicity_labels = config.get("labels")
+
+            if not toxicity_labels:
+                for item in response["ResultList"]:
+                    for label in item["Labels"]:
+                        if label["Score"] >= threshold:
+                            toxicity_found = True
+                            break
+            else:
+                for item in response["ResultList"]:
+                    for label in item["Labels"]:
+                        if (
+                            label["Name"] in toxicity_labels
+                            and label["Score"] >= threshold
+                        ):
+                            toxicity_found = True
+                            break
+
+            if self.callback and self.callback.toxicity_callback:
+                if toxicity_found:
+                    self.moderation_beacon["moderation_status"] = "LABELS_FOUND"
+                asyncio.create_task(
+                    self.callback.on_after_toxicity(
+                        self.moderation_beacon, self.unique_id
+                    )
+                )
+            if toxicity_found:
+                raise ModerationToxicityError
         return prompt_value