Harrison/string inplace (#10153)

Co-authored-by: Wrick Talukdar <wrick.talukdar@gmail.com>
Co-authored-by: Anjan Biswas <anjanavb@amazon.com>
Co-authored-by: Jha <nikjha@amazon.com>
Co-authored-by: Lucky-Lance <77819606+Lucky-Lance@users.noreply.github.com>
Co-authored-by: 陆徐东 <luxudong@MacBook-Pro.local>
This commit is contained in:
Harrison Chase
2023-09-03 14:25:29 -07:00
committed by GitHub
parent f5af756397
commit 4abe85be57
11 changed files with 1623 additions and 218 deletions

View File

@@ -1,7 +1,6 @@
import asyncio
import importlib
import warnings
from typing import Any, Dict, List, Optional
from typing import Any, List, Optional
from langchain_experimental.comprehend_moderation.base_moderation_exceptions import (
ModerationToxicityError,
@@ -30,14 +29,15 @@ class ComprehendToxicity:
Validate and initialize toxicity processing configuration.
Args:
max_size (int): Maximum sentence size defined in the configuration object.
max_size (int): Maximum sentence size defined in the
configuration object.
Raises:
Exception: If the maximum sentence size exceeds the 5KB limit.
Note:
This function ensures that the NLTK punkt tokenizer is downloaded if not
already present.
This function ensures that the NLTK punkt tokenizer is downloaded
if not already present.
Returns:
None
@@ -63,34 +63,36 @@ class ComprehendToxicity:
Split a paragraph into chunks of sentences, respecting the maximum size limit.
Args:
paragraph (str): The input paragraph to be split into chunks
max_size (int, optional): The maximum size limit in bytes for each chunk
Defaults to 1024.
paragraph (str): The input paragraph to be split into chunks.
max_size (int, optional): The maximum size limit in bytes for
each chunk. Defaults to 1024.
Returns:
List[List[str]]: A list of chunks, where each chunk is a list of sentences
List[List[str]]: A list of chunks, where each chunk is a list
of sentences.
Note:
This function validates the maximum sentence size based on service limits
using the 'toxicity_init_validate' function. It uses the NLTK sentence
tokenizer to split the paragraph into sentences.
This function validates the maximum sentence size based on service
limits using the 'toxicity_init_validate' function. It uses the NLTK
sentence tokenizer to split the paragraph into sentences.
Example:
paragraph = "This is a sample paragraph. It
contains multiple sentences. ..."
chunks = split_paragraph(paragraph, max_size=2048)
"""
# validate max. sentence size based on Service limits
nltk = self._toxicity_init_validate(max_size)
sentences = nltk.sent_tokenize(prompt_value)
chunks = []
current_chunk = [] # type: ignore
chunks = list() # type: ignore
current_chunk = list() # type: ignore
current_size = 0
for sentence in sentences:
sentence_size = len(sentence.encode("utf-8"))
# If adding a new sentence exceeds max_size or
# current_chunk has 10 sentences, start a new chunk
# If adding a new sentence exceeds max_size
# or current_chunk has 10 sentences, start a new chunk
if (current_size + sentence_size > max_size) or (len(current_chunk) >= 10):
if current_chunk: # Avoid appending empty chunks
chunks.append(current_chunk)
@@ -103,16 +105,12 @@ class ComprehendToxicity:
# Add any remaining sentences
if current_chunk:
chunks.append(current_chunk)
return chunks
def validate(
self, prompt_value: str, config: Optional[Dict[str, Any]] = None
) -> str:
def validate(self, prompt_value: str, config: Any = None) -> str:
"""
Check the toxicity of a given text prompt using AWS Comprehend service
and apply actions based on configuration.
Check the toxicity of a given text prompt using AWS
Comprehend service and apply actions based on configuration.
Args:
prompt_value (str): The text content to be checked for toxicity.
config (Dict[str, Any]): Configuration for toxicity checks and actions.
@@ -122,7 +120,7 @@ class ComprehendToxicity:
Raises:
ValueError: If the prompt contains toxic labels and cannot be
processed based on the configuration.
processed based on the configuration.
"""
chunks = self._split_paragraph(prompt_value=prompt_value)
@@ -134,76 +132,34 @@ class ComprehendToxicity:
if self.callback and self.callback.toxicity_callback:
self.moderation_beacon["moderation_input"] = segments # type: ignore
self.moderation_beacon["moderation_output"] = response
toxicity_found = False
threshold = config.get("threshold")
toxicity_labels = config.get("labels")
if config:
from langchain_experimental.comprehend_moderation.base_moderation_enums import ( # noqa: E501
BaseModerationActions,
)
toxicity_found = False
action = config.get("action", BaseModerationActions.STOP)
if action not in [
BaseModerationActions.STOP,
BaseModerationActions.ALLOW,
]:
raise ValueError("Action can either be stop or allow")
threshold = config.get("threshold", 0.5) if config else 0.5
toxicity_labels = config.get("labels", []) if config else []
if action == BaseModerationActions.STOP:
for item in response["ResultList"]:
for label in item["Labels"]:
if (
label
and (
not toxicity_labels
or label["Name"] in toxicity_labels
)
and label["Score"] >= threshold
):
toxicity_found = True
break
if action == BaseModerationActions.ALLOW:
if not toxicity_labels:
warnings.warn(
"You have allowed toxic content without specifying "
"any toxicity labels."
)
else:
for item in response["ResultList"]:
for label in item["Labels"]:
if (
label["Name"] in toxicity_labels
and label["Score"] >= threshold
):
toxicity_found = True
break
if self.callback and self.callback.toxicity_callback:
if toxicity_found:
self.moderation_beacon["moderation_status"] = "LABELS_FOUND"
asyncio.create_task(
self.callback.on_after_toxicity(
self.moderation_beacon, self.unique_id
)
)
if toxicity_found:
raise ModerationToxicityError
if not toxicity_labels:
for item in response["ResultList"]:
for label in item["Labels"]:
if label["Score"] >= threshold:
toxicity_found = True
break
else:
if response["ResultList"]:
detected_toxic_labels = list()
for item in response["ResultList"]:
detected_toxic_labels.extend(item["Labels"])
if any(item["Score"] >= 0.5 for item in detected_toxic_labels):
if self.callback and self.callback.toxicity_callback:
self.moderation_beacon["moderation_status"] = "LABELS_FOUND"
asyncio.create_task(
self.callback.on_after_toxicity(
self.moderation_beacon, self.unique_id
)
)
raise ModerationToxicityError
for item in response["ResultList"]:
for label in item["Labels"]:
if (
label["Name"] in toxicity_labels
and label["Score"] >= threshold
):
toxicity_found = True
break
if self.callback and self.callback.toxicity_callback:
if toxicity_found:
self.moderation_beacon["moderation_status"] = "LABELS_FOUND"
asyncio.create_task(
self.callback.on_after_toxicity(
self.moderation_beacon, self.unique_id
)
)
if toxicity_found:
raise ModerationToxicityError
return prompt_value