diff --git a/docs/docs_skeleton/docs/guides/safety/amazon_comprehend_chain.ipynb b/docs/docs_skeleton/docs/guides/safety/amazon_comprehend_chain.ipynb
new file mode 100644
index 00000000000..03277c16fed
--- /dev/null
+++ b/docs/docs_skeleton/docs/guides/safety/amazon_comprehend_chain.ipynb
@@ -0,0 +1,1396 @@
+{
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "id": "25a3f834-60b7-4c21-bfb4-ad16d30fd3f7",
+ "metadata": {},
+ "source": [
+ "# Amazon Comprehend Moderation Chain\n",
+ "---"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "2c4236d8-4054-473d-84a4-87a4db278a62",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "%pip install boto3 nltk"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "3f8518ad-c762-413c-b8c9-f1c211fc311d",
+ "metadata": {
+ "tags": []
+ },
+ "outputs": [],
+ "source": [
+ "import boto3\n",
+ "\n",
+ "comprehend_client = boto3.client('comprehend', \n",
+ " region_name='us-east-1', \n",
+ " aws_access_key_id=\"ASIA6BR6ZDLNQLMEGWHM\",\n",
+ " aws_secret_access_key=\"Y79nefFoOfvgrog6sojSe55xTuKqDJY53BgfrtlG\",\n",
+ " aws_session_token=\"IQoJb3JpZ2luX2VjEIP//////////wEaCXVzLWVhc3QtMSJGMEQCIBvUl0Wj5Gu5GrHB+i5fHkaVc2V1381M7UNRX8EggHORAiB+dG/uKJ4loHn2oAcXIEy6+lfU7wygl4zw/vUo2VItFiqfAghMEAIaDDk2NTQyNTU2ODQ3NSIMfbh8uyoO1XONSkuEKvwBTMxeDCi//9U9LGIwZZzIiHOudQAqR2wlIGZKcw//abSeHNBE1AoDT8ibcqk7EuIt9fwnj1WYiLGmSIWd9/kSZShiKdYg0UpNWyr1/LdeutV5byFAjT21RnWTgSMr0QeSCU698PFusvO1Coph8C75pcqTVYsxi/HypJT8OfB5iCxKgfzx0qD4X6hScpIAEYZhgQXHFBAeubqMkVPYEqSob6fSm1vEI8LkU8HG1N2M2p8TzGCQWo5uBgtNkipxve++bkR+xjiNLIpAN3P1xF2/W/lYlz+4xGsi90aZqIVh/tOvAjg7Yx1Dd5Ir2C0fZc7wbtabzVFlJZ7GFcpcMOX0o6cGOp4BismuW2CJRBmFFpoparqraQaiQBY/VDbQg9KQc/Y6o0oCxkESLUdY6ino3yrheT3W832eAg0RwrmEaQqT8kKGyJFimUxrAF/otNQhySLKuSXLooguammJiQAtgK1EhmuLBUBoLcngxQ31kDqw13g7Ccwuo68fnI/QzQLj5MX+V5VLCSp9VrOzi9XSjmeF/TJQARdZeL3CSeu2pATQc80=\"\n",
+ " )"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "d1f0ba28",
+ "metadata": {},
+ "source": [
+ "Import `AmazonComprehendModerationChain`"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "74550d74-3c01-4ba7-ad32-ca66d955d001",
+ "metadata": {
+ "tags": []
+ },
+ "outputs": [],
+ "source": [
+ "from langchain_experimental.comprehend_moderation import AmazonComprehendModerationChain"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "f00c338b-de9f-40e5-9295-93c9e26058e3",
+ "metadata": {},
+ "source": [
+ "Initialize an instance of the Amazon Comprehend Moderation Chain to be used with your LLM chain"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "cde58cc6-ff83-493a-9aed-93d755f984a7",
+ "metadata": {
+ "tags": []
+ },
+ "outputs": [],
+ "source": [
+ "comprehend_moderation = AmazonComprehendModerationChain(\n",
+ " client=comprehend_client, #optional\n",
+ " verbose=True\n",
+ ")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "ad646d01-82d2-435a-939b-c450693857ab",
+ "metadata": {},
+ "source": [
+ "Using it with your LLM chain. \n",
+ "\n",
+ "**Note**: The example below uses the _Fake LLM_ from LangChain, but same concept could be applied to other LLMs."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "0efa1946-d4a9-467a-920a-a8fb78720fc2",
+ "metadata": {
+ "tags": []
+ },
+ "outputs": [],
+ "source": [
+ "from langchain import PromptTemplate, LLMChain\n",
+ "from langchain.llms.fake import FakeListLLM\n",
+ "from langchain_experimental.comprehend_moderation.base_moderation_exceptions import ModerationPiiError\n",
+ "\n",
+ "template = \"\"\"Question: {question}\n",
+ "\n",
+ "Answer:\"\"\"\n",
+ "\n",
+ "prompt = PromptTemplate(template=template, input_variables=[\"question\"])\n",
+ "\n",
+ "responses = [\n",
+ " \"Final Answer: A credit card number looks like 1289-2321-1123-2387. A fake SSN number looks like 323-22-9980. John Doe's phone number is (999)253-9876.\", \n",
+ " \"Final Answer: This is a really shitty way of constructing a birdhouse. This is fucking insane to think that any birds would actually create their motherfucking nests here.\"\n",
+ "]\n",
+ "llm = FakeListLLM(responses=responses)\n",
+ "\n",
+ "llm_chain = LLMChain(prompt=prompt, llm=llm)\n",
+ "\n",
+ "chain = (\n",
+ " prompt \n",
+ " | comprehend_moderation \n",
+ " | {llm_chain.input_keys[0]: lambda x: x['output'] } \n",
+ " | llm_chain \n",
+ " | { \"input\": lambda x: x['text'] } \n",
+ " | comprehend_moderation \n",
+ ")\n",
+ "\n",
+ "try:\n",
+ " response = chain.invoke({\"question\": \"A sample SSN number looks like this 123-456-7890. Can you give me some more samples?\"})\n",
+ "except ModerationPiiError as e:\n",
+ " print(e.message)\n",
+ "else:\n",
+ " print(response['output'])\n"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "6da25d96-0d96-4c01-94ae-a2ead17f10aa",
+ "metadata": {},
+ "source": [
+ "## Using `moderation_config` to customize your moderation\n",
+ "---"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "bfd550e7-5012-41fa-9546-8b78ddf1c673",
+ "metadata": {},
+ "source": [
+ "Use Amazon Comprehend Moderation with a configuration to control what moderations you wish to perform and what actions should be taken for each of them. There are three different moderations that happen when no configuration is passed as demonstrated above. These moderations are:\n",
+ "\n",
+ "- PII (Personally Identifiable Information) checks \n",
+ "- Toxicity content detection\n",
+ "- Intention detection\n",
+ "\n",
+ "Here is an example of a moderation config."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "d6e8900a-44ef-4967-bde8-b88af282139d",
+ "metadata": {
+ "tags": []
+ },
+ "outputs": [],
+ "source": [
+ "from langchain_experimental.comprehend_moderation import BaseModerationActions, BaseModerationFilters\n",
+ "\n",
+ "moderation_config = { \n",
+ " \"filters\":[ \n",
+ " BaseModerationFilters.PII, \n",
+ " BaseModerationFilters.TOXICITY,\n",
+ " BaseModerationFilters.INTENT\n",
+ " ],\n",
+ " \"pii\":{ \n",
+ " \"action\": BaseModerationActions.ALLOW, \n",
+ " \"threshold\":0.5, \n",
+ " \"labels\":[\"SSN\"],\n",
+ " \"mask_character\": \"X\"\n",
+ " },\n",
+ " \"toxicity\":{ \n",
+ " \"action\": BaseModerationActions.STOP, \n",
+ " \"threshold\":0.5\n",
+ " },\n",
+ " \"intent\":{ \n",
+ " \"action\": BaseModerationActions.STOP, \n",
+ " \"threshold\":0.5\n",
+ " }\n",
+ "}"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "3634376b-5938-43df-9ed6-70ca7e99290f",
+ "metadata": {},
+ "source": [
+ "At the core of the configuration you have three filters specified in the `filters` key:\n",
+ "\n",
+ "1. `BaseModerationFilters.PII`\n",
+ "2. `BaseModerationFilters.TOXICITY`\n",
+ "3. `BaseModerationFilters.INTENT`\n",
+ "\n",
+ "And an `action` key that defines two possible actions for each moderation function:\n",
+ "\n",
+ "1. `BaseModerationActions.ALLOW` - `allows` the prompt to pass through but masks detected PII in case of PII check. The default behavior is to run and redact all PII entities. If there is an entity specified in the `labels` field, then only those entities will go through the PII check and masked.\n",
+ "2. `BaseModerationActions.STOP` - `stops` the prompt from passing through to the next step in case any PII, Toxicity, or incorrect Intent is detected. The action of `BaseModerationActions.STOP` will raise a Python `Exception` essentially stopping the chain in progress.\n",
+ "\n",
+ "Using the configuration in the previous cell will perform PII checks and will allow the prompt to pass through however it will mask any SSN numbers present in either the prompt or the LLM output.\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "3a4f7e65-f733-4863-ae6d-34c9faffd849",
+ "metadata": {
+ "tags": []
+ },
+ "outputs": [],
+ "source": [
+ "comp_moderation_with_config = AmazonComprehendModerationChain(\n",
+ " moderation_config=moderation_config, #specify the configuration\n",
+ " client=comprehend_client, #optionally pass the Boto3 Client\n",
+ " verbose=True\n",
+ ")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "a25e6f93-765b-4f99-8c1c-929157dbd4aa",
+ "metadata": {
+ "tags": []
+ },
+ "outputs": [],
+ "source": [
+ "template = \"\"\"Question: {question}\n",
+ "\n",
+ "Answer:\"\"\"\n",
+ "\n",
+ "prompt = PromptTemplate(template=template, input_variables=[\"question\"])\n",
+ "\n",
+ "responses = [\n",
+ " \"Final Answer: A credit card number looks like 1289-2321-1123-2387. A fake SSN number looks like 323-22-9980. John Doe's phone number is (999)253-9876.\", \n",
+ " \"Final Answer: This is a really shitty way of constructing a birdhouse. This is fucking insane to think that any birds would actually create their motherfucking nests here.\"\n",
+ "]\n",
+ "llm = FakeListLLM(responses=responses)\n",
+ "\n",
+ "llm_chain = LLMChain(prompt=prompt, llm=llm)\n",
+ "\n",
+ "chain = ( \n",
+ " prompt \n",
+ " | comp_moderation_with_config \n",
+ " | {llm_chain.input_keys[0]: lambda x: x['output'] } \n",
+ " | llm_chain \n",
+ " | { \"input\": lambda x: x['text'] } \n",
+ " | comp_moderation_with_config \n",
+ ")\n",
+ "\n",
+ "try:\n",
+ " response = chain.invoke({\"question\": \"A sample SSN number looks like this 123-456-7890. Can you give me some more samples?\"})\n",
+ "except Exception as e:\n",
+ " print(str(e))\n",
+ "else:\n",
+ " print(response['output'])"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "ba890681-feeb-43ca-a0d5-9c11d2d9de3e",
+ "metadata": {
+ "tags": []
+ },
+ "source": [
+ "## Unique ID, and Moderation Callbacks\n",
+ "---\n",
+ "\n",
+ "When Amazon Comprehend moderation action is specified as `STOP`, the chain will raise one of the following exceptions-\n",
+ " - `ModerationPiiError`, for PII checks\n",
+ " - `ModerationToxicityError`, for Toxicity checks \n",
+ " - `ModerationIntentionError` for Intent checks\n",
+ "\n",
+ "In addition to the moderation configuration, the `AmazonComprehendModerationChain` can also be initialized with the following parameters\n",
+ "\n",
+ "- `unique_id` [Optional] a string parameter. This parameter can be used to pass any string value or ID. For example, in a chat application you may want to keep track of abusive users, in this case you can pass the user's username/email id etc. This defaults to `None`.\n",
+ "\n",
+ "- `moderation_callback` [Optional] the `BaseModerationCallbackHandler` that will be called asynchronously (non-blocking to the chain). Callback functions are useful when you want to perform additional actions when the moderation functions are executed, for example logging into a database, or writing a log file. You can override three functions by subclassing `BaseModerationCallbackHandler` - `on_after_pii()`, `on_after_toxicity()`, and `on_after_intent()`. Note that all three functions must be `async` functions. These callback functions receive two arguments:\n",
+ " - `moderation_beacon` a dictionary that will contain information about the moderation function, the full response from Amazon Comprehend model, a unique chain id, the moderation status, and the input string which was validated. The dictionary is of the following schema-\n",
+ " \n",
+ " ```\n",
+ " { \n",
+ " 'moderation_chain_id': 'xxx-xxx-xxx', # Unique chain ID\n",
+ " 'moderation_type': 'Toxicity' | 'PII' | 'Intent', \n",
+ " 'moderation_status': 'LABELS_FOUND' | 'LABELS_NOT_FOUND',\n",
+ " 'moderation_input': 'A sample SSN number looks like this 123-456-7890. Can you give me some more samples?',\n",
+ " 'moderation_output': {...} #Full Amazon Comprehend PII, Toxicity, or Intent Model Output\n",
+ " }\n",
+ " ```\n",
+ " \n",
+ " - `unique_id` if passed to the `AmazonComprehendModerationChain`"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "3c178835-0264-4ac6-aef4-091d2993d06c",
+ "metadata": {},
+ "source": [
+ "
NOTE: moderation_callback
is different from LangChain Chain Callbacks. You can still use LangChain Chain callbacks with
AmazonComprehendModerationChain
via the callbacks parameter. Example:
\n",
+ "
\n",
+ "from langchain.callbacks.stdout import StdOutCallbackHandler\n",
+ "comp_moderation_with_config = AmazonComprehendModerationChain(verbose=True, callbacks=[StdOutCallbackHandler()])\n",
+ "
\n",
+ "
"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "0ec38536-8cc9-408e-860b-e4a439283643",
+ "metadata": {
+ "tags": []
+ },
+ "outputs": [],
+ "source": [
+ "from langchain_experimental.comprehend_moderation import BaseModerationCallbackHandler"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "1be744c7-3f99-4165-bf7f-9c5c249bbb53",
+ "metadata": {
+ "tags": []
+ },
+ "outputs": [],
+ "source": [
+ "# Define callback handlers by subclassing BaseModerationCallbackHandler\n",
+ "\n",
+ "class MyModCallback(BaseModerationCallbackHandler):\n",
+ " \n",
+ " async def on_after_pii(self, output_beacon, unique_id):\n",
+ " import json\n",
+ " moderation_type = output_beacon['moderation_type']\n",
+ " chain_id = output_beacon['moderation_chain_id']\n",
+ " with open(f'output-{moderation_type}-{chain_id}.json', 'w') as file:\n",
+ " data = { 'beacon_data': output_beacon, 'unique_id': unique_id }\n",
+ " json.dump(data, file)\n",
+ " \n",
+ " '''\n",
+ " async def on_after_toxicity(self, output_beacon, unique_id):\n",
+ " pass\n",
+ " \n",
+ " async def on_after_intent(self, output_beacon, unique_id):\n",
+ " pass\n",
+ " '''\n",
+ " \n",
+ "\n",
+ "my_callback = MyModCallback()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "362a3fe0-f09f-411e-9df1-d79b3e87510c",
+ "metadata": {
+ "tags": []
+ },
+ "outputs": [],
+ "source": [
+ "moderation_config = { \n",
+ " \"filters\": [ \n",
+ " BaseModerationFilters.PII, \n",
+ " BaseModerationFilters.TOXICITY\n",
+ " ],\n",
+ " \"pii\":{ \n",
+ " \"action\": BaseModerationActions.STOP, \n",
+ " \"threshold\":0.5, \n",
+ " \"labels\":[\"SSN\"], \n",
+ " \"mask_character\": \"X\" \n",
+ " },\n",
+ " \"toxicity\":{ \n",
+ " \"action\": BaseModerationActions.STOP, \n",
+ " \"threshold\":0.5 \n",
+ " }\n",
+ "}\n",
+ "\n",
+ "comp_moderation_with_config = AmazonComprehendModerationChain(\n",
+ " moderation_config=moderation_config, # specify the configuration\n",
+ " client=comprehend_client, # optionally pass the Boto3 Client\n",
+ " force_base_exception=True, # Force BaseModerationError\n",
+ " unique_id='john.doe@email.com', # A unique ID\n",
+ " moderation_callback=my_callback, # BaseModerationCallbackHandler\n",
+ " verbose=True\n",
+ ")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "2af07937-67ea-4738-8343-c73d4d28c2cc",
+ "metadata": {
+ "tags": []
+ },
+ "outputs": [],
+ "source": [
+ "from langchain import PromptTemplate, LLMChain\n",
+ "from langchain.llms.fake import FakeListLLM\n",
+ "\n",
+ "template = \"\"\"Question: {question}\n",
+ "\n",
+ "Answer:\"\"\"\n",
+ "\n",
+ "prompt = PromptTemplate(template=template, input_variables=[\"question\"])\n",
+ "\n",
+ "responses = [\n",
+ " \"Final Answer: A credit card number looks like 1289-2321-1123-2387. A fake SSN number looks like 323-22-9980. John Doe's phone number is (999)253-9876.\", \n",
+ " \"Final Answer: This is a really shitty way of constructing a birdhouse. This is fucking insane to think that any birds would actually create their motherfucking nests here.\"\n",
+ "]\n",
+ "\n",
+ "llm = FakeListLLM(responses=responses)\n",
+ "\n",
+ "llm_chain = LLMChain(prompt=prompt, llm=llm)\n",
+ "\n",
+ "chain = (\n",
+ " prompt \n",
+ " | comp_moderation_with_config \n",
+ " | {llm_chain.input_keys[0]: lambda x: x['output'] } \n",
+ " | llm_chain \n",
+ " | { \"input\": lambda x: x['text'] } \n",
+ " | comp_moderation_with_config \n",
+ ") \n",
+ "\n",
+ "try:\n",
+ " response = chain.invoke({\"question\": \"A sample SSN number looks like this 123-456-7890. Can you give me some more samples?\"})\n",
+ "except Exception as e:\n",
+ " print(str(e))\n",
+ "else:\n",
+ " print(response['output'])"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "706454b2-2efa-4d41-abc8-ccf2b4e87822",
+ "metadata": {
+ "tags": []
+ },
+ "source": [
+ "## `moderation_config` and moderation execution order\n",
+ "---\n",
+ "\n",
+ "If `AmazonComprehendModerationChain` is not initialized with any `moderation_config` then the default action is `STOP` and default order of moderation check is as follows.\n",
+ "\n",
+ "```\n",
+ "AmazonComprehendModerationChain\n",
+ "│\n",
+ "└──Check PII with Stop Action\n",
+ " ├── Callback (if available)\n",
+ " ├── Label Found ⟶ [Error Stop]\n",
+ " └── No Label Found \n",
+ " └──Check Toxicity with Stop Action\n",
+ " ├── Callback (if available)\n",
+ " ├── Label Found ⟶ [Error Stop]\n",
+ " └── No Label Found\n",
+ " └──Check Intent with Stop Action\n",
+ " ├── Callback (if available)\n",
+ " ├── Label Found ⟶ [Error Stop]\n",
+ " └── No Label Found\n",
+ " └── Return Prompt\n",
+ "```\n",
+ "\n",
+ "If any of the check raises exception then the subsequent checks will not be performed. If a `callback` is provided in this case, then it will be called for each of the checks that have been performed. For example, in the case above, if the Chain fails due to presence of PII then the Toxicity and Intent checks will not be performed.\n",
+ "\n",
+ "You can override the execution order by passing `moderation_config` and simply specifying the desired order in the `filters` key of the configuration. In case you use `moderation_config` then the order of the checks as specified in the `filters` key will be maintained. For example, in the configuration below, first Toxicity check will be performed, then PII, and finally Intent validation will be performed. In this case, `AmazonComprehendModerationChain` will perform the desired checks in the specified order with default values of each model `kwargs`.\n",
+ "\n",
+ "```python\n",
+ "moderation_config = { \n",
+ " \"filters\":[ BaseModerationFilters.TOXICITY, \n",
+ " BaseModerationFilters.PII, \n",
+ " BaseModerationFilters.INTENT]\n",
+ " }\n",
+ "```\n",
+ "\n",
+ "Model `kwargs` are specified by the `pii`, `toxicity`, and `intent` keys within the `moderation_config` dictionary. For example, in the `moderation_config` below, the default order of moderation is overriden and the `pii` & `toxicity` model `kwargs` have been overriden. For `intent` the chain's default `kwargs` will be used.\n",
+ "\n",
+ "```python\n",
+ " moderation_config = { \n",
+ " \"filters\":[ BaseModerationFilters.TOXICITY, \n",
+ " BaseModerationFilters.PII, \n",
+ " BaseModerationFilters.INTENT],\n",
+ " \"pii\":{ \"action\": BaseModerationActions.ALLOW, \n",
+ " \"threshold\":0.5, \n",
+ " \"labels\":[\"SSN\"], \n",
+ " \"mask_character\": \"X\" },\n",
+ " \"toxicity\":{ \"action\": BaseModerationActions.STOP, \n",
+ " \"threshold\":0.5 }\n",
+ " }\n",
+ "```\n",
+ "\n",
+ "1. For a list of PII labels see Amazon Comprehend Universal PII entity types - https://docs.aws.amazon.com/comprehend/latest/dg/how-pii.html#how-pii-types\n",
+ "2. Following are the list of available Toxicity labels-\n",
+ " - `HATE_SPEECH`: Speech that criticizes, insults, denounces or dehumanizes a person or a group on the basis of an identity, be it race, ethnicity, gender identity, religion, sexual orientation, ability, national origin, or another identity-group.\n",
+ " - `GRAPHIC`: Speech that uses visually descriptive, detailed and unpleasantly vivid imagery is considered as graphic. Such language is often made verbose so as to amplify an insult, discomfort or harm to the recipient.\n",
+ " - `HARASSMENT_OR_ABUSE`: Speech that imposes disruptive power dynamics between the speaker and hearer, regardless of intent, seeks to affect the psychological well-being of the recipient, or objectifies a person should be classified as Harassment.\n",
+ " - `SEXUAL`: Speech that indicates sexual interest, activity or arousal by using direct or indirect references to body parts or physical traits or sex is considered as toxic with toxicityType \"sexual\". \n",
+ " - `VIOLENCE_OR_THREAT`: Speech that includes threats which seek to inflict pain, injury or hostility towards a person or group.\n",
+ " - `INSULT`: Speech that includes demeaning, humiliating, mocking, insulting, or belittling language.\n",
+ " - `PROFANITY`: Speech that contains words, phrases or acronyms that are impolite, vulgar, or offensive is considered as profane.\n",
+ "3. For a list of Intent labels refer to documentation [link here]"
+ ]
+ },
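+ {
+ "cell_type": "markdown",
+ "id": "b7c3a9f1-1111-4aaa-9bbb-000000000001",
+ "metadata": {},
+ "source": [
+ "Optionally, you can call Amazon Comprehend directly to inspect the raw PII entity types and confidence scores it returns for a piece of text. This is a minimal sketch that reuses the `comprehend_client` created at the top of this notebook; the sample text is illustrative only."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "b7c3a9f1-1111-4aaa-9bbb-000000000002",
+ "metadata": {
+ "tags": []
+ },
+ "outputs": [],
+ "source": [
+ "# Inspect the raw PII entity types and scores returned by Amazon Comprehend.\n",
+ "# Reuses the comprehend_client created earlier in this notebook.\n",
+ "sample_text = \"A fake SSN number looks like 323-22-9980. John Doe's phone number is (999)253-9876.\"\n",
+ "pii_response = comprehend_client.detect_pii_entities(Text=sample_text, LanguageCode=\"en\")\n",
+ "for entity in pii_response[\"Entities\"]:\n",
+ "    print(entity[\"Type\"], round(entity[\"Score\"], 3))"
+ ]
+ },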
+ {
+ "cell_type": "markdown",
+ "id": "78905aec-55ae-4fc3-a23b-8a69bd1e33f2",
+ "metadata": {},
+ "source": [
+ "# Examples\n",
+ "---\n",
+ "\n",
+ "## With HuggingFace Hub Models\n",
+ "\n",
+ "Get your API Key from Huggingface hub - https://huggingface.co/docs/api-inference/quicktour#get-your-api-token"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "359b9627-769b-46ce-8be2-c8a5cf7728ba",
+ "metadata": {
+ "scrolled": true,
+ "tags": []
+ },
+ "outputs": [],
+ "source": [
+ "%pip install huggingface_hub"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "41b7ea98-ad16-4454-8f12-c03c17113a86",
+ "metadata": {
+ "tags": []
+ },
+ "outputs": [],
+ "source": [
+ "%env HUGGINGFACEHUB_API_TOKEN=\"\""
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "3b235427-cc06-4c07-874b-1f67c2d1f924",
+ "metadata": {
+ "tags": []
+ },
+ "outputs": [],
+ "source": [
+ "# See https://huggingface.co/models?pipeline_tag=text-generation&sort=downloads for some other options\n",
+ "repo_id = \"google/flan-t5-xxl\" \n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "9d86e256-34fb-4c8e-8092-1a4f863a5c96",
+ "metadata": {
+ "tags": []
+ },
+ "outputs": [],
+ "source": [
+ "from langchain import HuggingFaceHub\n",
+ "from langchain import PromptTemplate, LLMChain\n",
+ "\n",
+ "template = \"\"\"Question: {question}\n",
+ "\n",
+ "Answer:\"\"\"\n",
+ "\n",
+ "prompt = PromptTemplate(template=template, input_variables=[\"question\"])\n",
+ "\n",
+ "llm = HuggingFaceHub(\n",
+ " repo_id=repo_id, model_kwargs={\"temperature\": 0.5, \"max_length\": 256}\n",
+ ")\n",
+ "llm_chain = LLMChain(prompt=prompt, llm=llm)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "ad603796-ad8b-4599-9022-a486f1c1b89a",
+ "metadata": {},
+ "source": [
+ "Create a configuration and initialize an Amazon Comprehend Moderation chain"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "decc3409-5be5-433d-b6da-38b9e5c5ee3f",
+ "metadata": {
+ "tags": []
+ },
+ "outputs": [],
+ "source": [
+ "moderation_config = { \n",
+ " \"filters\":[ BaseModerationFilters.PII, BaseModerationFilters.TOXICITY, BaseModerationFilters.INTENT ],\n",
+ " \"pii\":{\"action\": BaseModerationActions.ALLOW, \"threshold\":0.5, \"labels\":[\"SSN\",\"CREDIT_DEBIT_NUMBER\"], \"mask_character\": \"X\"},\n",
+ " \"toxicity\":{\"action\": BaseModerationActions.STOP, \"threshold\":0.5},\n",
+ " \"intent\":{\"action\": BaseModerationActions.ALLOW, \"threshold\":0.5,},\n",
+ " }\n",
+ "\n",
+ "# without any callback\n",
+ "amazon_comp_moderation = AmazonComprehendModerationChain(moderation_config=moderation_config, \n",
+ " client=comprehend_client,\n",
+ " verbose=True)\n",
+ "\n",
+ "# with callback\n",
+ "amazon_comp_moderation_out = AmazonComprehendModerationChain(moderation_config=moderation_config, \n",
+ " client=comprehend_client,\n",
+ " moderation_callback=my_callback,\n",
+ " verbose=True)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "b1256bc8-1321-4624-9e8a-a2d4a8df59bf",
+ "metadata": {},
+ "source": [
+ "The `moderation_config` will now prevent any inputs and model outputs containing obscene words or sentences, bad intent, or PII with entities other than SSN with score above threshold or 0.5 or 50%. If it finds Pii entities - SSN - it will redact them before allowing the call to proceed. "
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "0337becc-7c3c-483e-a55c-a225226cb9ee",
+ "metadata": {
+ "tags": []
+ },
+ "outputs": [],
+ "source": [
+ "chain = (\n",
+ " prompt \n",
+ " | amazon_comp_moderation \n",
+ " | {llm_chain.input_keys[0]: lambda x: x['output'] } \n",
+ " | llm_chain \n",
+ " | { \"input\": lambda x: x['text'] } \n",
+ " | amazon_comp_moderation_out\n",
+ ")\n",
+ "\n",
+ "try:\n",
+ " response = chain.invoke({\"question\": \"My AnyCompany Financial Services, LLC credit card account 1111-0000-1111-0008 has 24$ due by July 31st. Can you give me some more credit car number samples?\"})\n",
+ "except Exception as e:\n",
+ " print(str(e))\n",
+ "else:\n",
+ " print(response['output'])"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "ee52c7b8-6526-4f68-a2b3-b5ad3cf82489",
+ "metadata": {
+ "tags": []
+ },
+ "source": [
+ "---\n",
+ "## With Amazon SageMaker Jumpstart\n",
+ "\n",
+ "The exmaple below shows how to use Amazon Comprehend Moderation chain with an Amazon SageMaker Jumpstart hosted LLM. You should have an Amazon SageMaker Jumpstart hosted LLM endpoint within your AWS Account. "
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "cd49d075-bc23-4ab8-a92c-0ddbbc436c30",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "endpoint_name = \"\" # replace with your SageMaker Endpoint name"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "5978a5e6-667d-4926-842c-d965f88e5640",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "from langchain import SagemakerEndpoint\n",
+ "from langchain.llms.sagemaker_endpoint import LLMContentHandler\n",
+ "from langchain.chains import LLMChain\n",
+ "from langchain.prompts import load_prompt, PromptTemplate\n",
+ "import json\n",
+ "\n",
+ "class ContentHandler(LLMContentHandler):\n",
+ " content_type = \"application/json\"\n",
+ " accepts = \"application/json\"\n",
+ "\n",
+ " def transform_input(self, prompt: str, model_kwargs: dict) -> bytes:\n",
+ " input_str = json.dumps({\"text_inputs\": prompt, **model_kwargs})\n",
+ " return input_str.encode('utf-8')\n",
+ " \n",
+ " def transform_output(self, output: bytes) -> str:\n",
+ " response_json = json.loads(output.read().decode(\"utf-8\"))\n",
+ " return response_json['generated_texts'][0]\n",
+ "\n",
+ "content_handler = ContentHandler()\n",
+ "\n",
+ "#prompt template for input text\n",
+ "llm_prompt = PromptTemplate(input_variables=[\"input_text\"], template=\"{input_text}\")\n",
+ "\n",
+ "llm_chain = LLMChain(\n",
+ " llm=SagemakerEndpoint(\n",
+ " endpoint_name=endpoint_name, \n",
+ " region_name='us-east-1',\n",
+ " model_kwargs={\"temperature\":0.97,\n",
+ " \"max_length\": 200,\n",
+ " \"num_return_sequences\": 3,\n",
+ " \"top_k\": 50,\n",
+ " \"top_p\": 0.95,\n",
+ " \"do_sample\": True},\n",
+ " content_handler=content_handler\n",
+ " ),\n",
+ " prompt=llm_prompt\n",
+ ")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "d577b036-99a4-47fe-9a8e-4a34aa4cd88d",
+ "metadata": {},
+ "source": [
+ "Create a configuration and initialize an Amazon Comprehend Moderation chain"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "859da135-94d3-4a9c-970e-a873913592e2",
+ "metadata": {
+ "tags": []
+ },
+ "outputs": [],
+ "source": [
+ "moderation_config = { \n",
+ " \"filters\":[ BaseModerationFilters.PII, BaseModerationFilters.TOXICITY ],\n",
+ " \"pii\":{\"action\": BaseModerationActions.ALLOW, \"threshold\":0.5, \"labels\":[\"SSN\"], \"mask_character\": \"X\"},\n",
+ " \"toxicity\":{\"action\": BaseModerationActions.STOP, \"threshold\":0.5},\n",
+ " \"intent\":{\"action\": BaseModerationActions.ALLOW, \"threshold\":0.5,},\n",
+ " }\n",
+ "\n",
+ "amazon_comp_moderation = AmazonComprehendModerationChain(moderation_config=moderation_config, \n",
+ " client=comprehend_client ,\n",
+ " verbose=True)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "9abb191f-7a96-4077-8c30-b9ddc225bd6b",
+ "metadata": {},
+ "source": [
+ "The `moderation_config` will now prevent any inputs and model outputs containing obscene words or sentences, bad intent, or Pii with entities other than SSN with score above threshold or 0.5 or 50%. If it finds Pii entities - SSN - it will redact them before allowing the call to proceed. "
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "6db5aa2a-9c00-42a0-8e24-c5ba39994f7d",
+ "metadata": {
+ "tags": []
+ },
+ "outputs": [],
+ "source": [
+ "chain = (\n",
+ " prompt \n",
+ " | amazon_comp_moderation \n",
+ " | {llm_chain.input_keys[0]: lambda x: x['output'] } \n",
+ " | llm_chain \n",
+ " | { \"input\": lambda x: x['text'] } \n",
+ " | amazon_comp_moderation \n",
+ ")\n",
+ "\n",
+ "try:\n",
+ " response = chain.invoke({\"question\": \"My AnyCompany Financial Services, LLC credit card account 1111-0000-1111-0008 has 24$ due by July 31st. Can you give me some more samples?\"})\n",
+ "except Exception as e:\n",
+ " print(str(e))\n",
+ "else:\n",
+ " print(response['output'])"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "7fdfedf9-1a0a-4a9f-a6b0-d9ed2dbaa5ad",
+ "metadata": {},
+ "outputs": [],
+ "source": []
+ }
+ ],
+ "metadata": {
+ "availableInstances": [
+ {
+ "_defaultOrder": 0,
+ "_isFastLaunch": true,
+ "category": "General purpose",
+ "gpuNum": 0,
+ "hideHardwareSpecs": false,
+ "memoryGiB": 4,
+ "name": "ml.t3.medium",
+ "vcpuNum": 2
+ },
+ {
+ "_defaultOrder": 1,
+ "_isFastLaunch": false,
+ "category": "General purpose",
+ "gpuNum": 0,
+ "hideHardwareSpecs": false,
+ "memoryGiB": 8,
+ "name": "ml.t3.large",
+ "vcpuNum": 2
+ },
+ {
+ "_defaultOrder": 2,
+ "_isFastLaunch": false,
+ "category": "General purpose",
+ "gpuNum": 0,
+ "hideHardwareSpecs": false,
+ "memoryGiB": 16,
+ "name": "ml.t3.xlarge",
+ "vcpuNum": 4
+ },
+ {
+ "_defaultOrder": 3,
+ "_isFastLaunch": false,
+ "category": "General purpose",
+ "gpuNum": 0,
+ "hideHardwareSpecs": false,
+ "memoryGiB": 32,
+ "name": "ml.t3.2xlarge",
+ "vcpuNum": 8
+ },
+ {
+ "_defaultOrder": 4,
+ "_isFastLaunch": true,
+ "category": "General purpose",
+ "gpuNum": 0,
+ "hideHardwareSpecs": false,
+ "memoryGiB": 8,
+ "name": "ml.m5.large",
+ "vcpuNum": 2
+ },
+ {
+ "_defaultOrder": 5,
+ "_isFastLaunch": false,
+ "category": "General purpose",
+ "gpuNum": 0,
+ "hideHardwareSpecs": false,
+ "memoryGiB": 16,
+ "name": "ml.m5.xlarge",
+ "vcpuNum": 4
+ },
+ {
+ "_defaultOrder": 6,
+ "_isFastLaunch": false,
+ "category": "General purpose",
+ "gpuNum": 0,
+ "hideHardwareSpecs": false,
+ "memoryGiB": 32,
+ "name": "ml.m5.2xlarge",
+ "vcpuNum": 8
+ },
+ {
+ "_defaultOrder": 7,
+ "_isFastLaunch": false,
+ "category": "General purpose",
+ "gpuNum": 0,
+ "hideHardwareSpecs": false,
+ "memoryGiB": 64,
+ "name": "ml.m5.4xlarge",
+ "vcpuNum": 16
+ },
+ {
+ "_defaultOrder": 8,
+ "_isFastLaunch": false,
+ "category": "General purpose",
+ "gpuNum": 0,
+ "hideHardwareSpecs": false,
+ "memoryGiB": 128,
+ "name": "ml.m5.8xlarge",
+ "vcpuNum": 32
+ },
+ {
+ "_defaultOrder": 9,
+ "_isFastLaunch": false,
+ "category": "General purpose",
+ "gpuNum": 0,
+ "hideHardwareSpecs": false,
+ "memoryGiB": 192,
+ "name": "ml.m5.12xlarge",
+ "vcpuNum": 48
+ },
+ {
+ "_defaultOrder": 10,
+ "_isFastLaunch": false,
+ "category": "General purpose",
+ "gpuNum": 0,
+ "hideHardwareSpecs": false,
+ "memoryGiB": 256,
+ "name": "ml.m5.16xlarge",
+ "vcpuNum": 64
+ },
+ {
+ "_defaultOrder": 11,
+ "_isFastLaunch": false,
+ "category": "General purpose",
+ "gpuNum": 0,
+ "hideHardwareSpecs": false,
+ "memoryGiB": 384,
+ "name": "ml.m5.24xlarge",
+ "vcpuNum": 96
+ },
+ {
+ "_defaultOrder": 12,
+ "_isFastLaunch": false,
+ "category": "General purpose",
+ "gpuNum": 0,
+ "hideHardwareSpecs": false,
+ "memoryGiB": 8,
+ "name": "ml.m5d.large",
+ "vcpuNum": 2
+ },
+ {
+ "_defaultOrder": 13,
+ "_isFastLaunch": false,
+ "category": "General purpose",
+ "gpuNum": 0,
+ "hideHardwareSpecs": false,
+ "memoryGiB": 16,
+ "name": "ml.m5d.xlarge",
+ "vcpuNum": 4
+ },
+ {
+ "_defaultOrder": 14,
+ "_isFastLaunch": false,
+ "category": "General purpose",
+ "gpuNum": 0,
+ "hideHardwareSpecs": false,
+ "memoryGiB": 32,
+ "name": "ml.m5d.2xlarge",
+ "vcpuNum": 8
+ },
+ {
+ "_defaultOrder": 15,
+ "_isFastLaunch": false,
+ "category": "General purpose",
+ "gpuNum": 0,
+ "hideHardwareSpecs": false,
+ "memoryGiB": 64,
+ "name": "ml.m5d.4xlarge",
+ "vcpuNum": 16
+ },
+ {
+ "_defaultOrder": 16,
+ "_isFastLaunch": false,
+ "category": "General purpose",
+ "gpuNum": 0,
+ "hideHardwareSpecs": false,
+ "memoryGiB": 128,
+ "name": "ml.m5d.8xlarge",
+ "vcpuNum": 32
+ },
+ {
+ "_defaultOrder": 17,
+ "_isFastLaunch": false,
+ "category": "General purpose",
+ "gpuNum": 0,
+ "hideHardwareSpecs": false,
+ "memoryGiB": 192,
+ "name": "ml.m5d.12xlarge",
+ "vcpuNum": 48
+ },
+ {
+ "_defaultOrder": 18,
+ "_isFastLaunch": false,
+ "category": "General purpose",
+ "gpuNum": 0,
+ "hideHardwareSpecs": false,
+ "memoryGiB": 256,
+ "name": "ml.m5d.16xlarge",
+ "vcpuNum": 64
+ },
+ {
+ "_defaultOrder": 19,
+ "_isFastLaunch": false,
+ "category": "General purpose",
+ "gpuNum": 0,
+ "hideHardwareSpecs": false,
+ "memoryGiB": 384,
+ "name": "ml.m5d.24xlarge",
+ "vcpuNum": 96
+ },
+ {
+ "_defaultOrder": 20,
+ "_isFastLaunch": false,
+ "category": "General purpose",
+ "gpuNum": 0,
+ "hideHardwareSpecs": true,
+ "memoryGiB": 0,
+ "name": "ml.geospatial.interactive",
+ "supportedImageNames": [
+ "sagemaker-geospatial-v1-0"
+ ],
+ "vcpuNum": 0
+ },
+ {
+ "_defaultOrder": 21,
+ "_isFastLaunch": true,
+ "category": "Compute optimized",
+ "gpuNum": 0,
+ "hideHardwareSpecs": false,
+ "memoryGiB": 4,
+ "name": "ml.c5.large",
+ "vcpuNum": 2
+ },
+ {
+ "_defaultOrder": 22,
+ "_isFastLaunch": false,
+ "category": "Compute optimized",
+ "gpuNum": 0,
+ "hideHardwareSpecs": false,
+ "memoryGiB": 8,
+ "name": "ml.c5.xlarge",
+ "vcpuNum": 4
+ },
+ {
+ "_defaultOrder": 23,
+ "_isFastLaunch": false,
+ "category": "Compute optimized",
+ "gpuNum": 0,
+ "hideHardwareSpecs": false,
+ "memoryGiB": 16,
+ "name": "ml.c5.2xlarge",
+ "vcpuNum": 8
+ },
+ {
+ "_defaultOrder": 24,
+ "_isFastLaunch": false,
+ "category": "Compute optimized",
+ "gpuNum": 0,
+ "hideHardwareSpecs": false,
+ "memoryGiB": 32,
+ "name": "ml.c5.4xlarge",
+ "vcpuNum": 16
+ },
+ {
+ "_defaultOrder": 25,
+ "_isFastLaunch": false,
+ "category": "Compute optimized",
+ "gpuNum": 0,
+ "hideHardwareSpecs": false,
+ "memoryGiB": 72,
+ "name": "ml.c5.9xlarge",
+ "vcpuNum": 36
+ },
+ {
+ "_defaultOrder": 26,
+ "_isFastLaunch": false,
+ "category": "Compute optimized",
+ "gpuNum": 0,
+ "hideHardwareSpecs": false,
+ "memoryGiB": 96,
+ "name": "ml.c5.12xlarge",
+ "vcpuNum": 48
+ },
+ {
+ "_defaultOrder": 27,
+ "_isFastLaunch": false,
+ "category": "Compute optimized",
+ "gpuNum": 0,
+ "hideHardwareSpecs": false,
+ "memoryGiB": 144,
+ "name": "ml.c5.18xlarge",
+ "vcpuNum": 72
+ },
+ {
+ "_defaultOrder": 28,
+ "_isFastLaunch": false,
+ "category": "Compute optimized",
+ "gpuNum": 0,
+ "hideHardwareSpecs": false,
+ "memoryGiB": 192,
+ "name": "ml.c5.24xlarge",
+ "vcpuNum": 96
+ },
+ {
+ "_defaultOrder": 29,
+ "_isFastLaunch": true,
+ "category": "Accelerated computing",
+ "gpuNum": 1,
+ "hideHardwareSpecs": false,
+ "memoryGiB": 16,
+ "name": "ml.g4dn.xlarge",
+ "vcpuNum": 4
+ },
+ {
+ "_defaultOrder": 30,
+ "_isFastLaunch": false,
+ "category": "Accelerated computing",
+ "gpuNum": 1,
+ "hideHardwareSpecs": false,
+ "memoryGiB": 32,
+ "name": "ml.g4dn.2xlarge",
+ "vcpuNum": 8
+ },
+ {
+ "_defaultOrder": 31,
+ "_isFastLaunch": false,
+ "category": "Accelerated computing",
+ "gpuNum": 1,
+ "hideHardwareSpecs": false,
+ "memoryGiB": 64,
+ "name": "ml.g4dn.4xlarge",
+ "vcpuNum": 16
+ },
+ {
+ "_defaultOrder": 32,
+ "_isFastLaunch": false,
+ "category": "Accelerated computing",
+ "gpuNum": 1,
+ "hideHardwareSpecs": false,
+ "memoryGiB": 128,
+ "name": "ml.g4dn.8xlarge",
+ "vcpuNum": 32
+ },
+ {
+ "_defaultOrder": 33,
+ "_isFastLaunch": false,
+ "category": "Accelerated computing",
+ "gpuNum": 4,
+ "hideHardwareSpecs": false,
+ "memoryGiB": 192,
+ "name": "ml.g4dn.12xlarge",
+ "vcpuNum": 48
+ },
+ {
+ "_defaultOrder": 34,
+ "_isFastLaunch": false,
+ "category": "Accelerated computing",
+ "gpuNum": 1,
+ "hideHardwareSpecs": false,
+ "memoryGiB": 256,
+ "name": "ml.g4dn.16xlarge",
+ "vcpuNum": 64
+ },
+ {
+ "_defaultOrder": 35,
+ "_isFastLaunch": false,
+ "category": "Accelerated computing",
+ "gpuNum": 1,
+ "hideHardwareSpecs": false,
+ "memoryGiB": 61,
+ "name": "ml.p3.2xlarge",
+ "vcpuNum": 8
+ },
+ {
+ "_defaultOrder": 36,
+ "_isFastLaunch": false,
+ "category": "Accelerated computing",
+ "gpuNum": 4,
+ "hideHardwareSpecs": false,
+ "memoryGiB": 244,
+ "name": "ml.p3.8xlarge",
+ "vcpuNum": 32
+ },
+ {
+ "_defaultOrder": 37,
+ "_isFastLaunch": false,
+ "category": "Accelerated computing",
+ "gpuNum": 8,
+ "hideHardwareSpecs": false,
+ "memoryGiB": 488,
+ "name": "ml.p3.16xlarge",
+ "vcpuNum": 64
+ },
+ {
+ "_defaultOrder": 38,
+ "_isFastLaunch": false,
+ "category": "Accelerated computing",
+ "gpuNum": 8,
+ "hideHardwareSpecs": false,
+ "memoryGiB": 768,
+ "name": "ml.p3dn.24xlarge",
+ "vcpuNum": 96
+ },
+ {
+ "_defaultOrder": 39,
+ "_isFastLaunch": false,
+ "category": "Memory Optimized",
+ "gpuNum": 0,
+ "hideHardwareSpecs": false,
+ "memoryGiB": 16,
+ "name": "ml.r5.large",
+ "vcpuNum": 2
+ },
+ {
+ "_defaultOrder": 40,
+ "_isFastLaunch": false,
+ "category": "Memory Optimized",
+ "gpuNum": 0,
+ "hideHardwareSpecs": false,
+ "memoryGiB": 32,
+ "name": "ml.r5.xlarge",
+ "vcpuNum": 4
+ },
+ {
+ "_defaultOrder": 41,
+ "_isFastLaunch": false,
+ "category": "Memory Optimized",
+ "gpuNum": 0,
+ "hideHardwareSpecs": false,
+ "memoryGiB": 64,
+ "name": "ml.r5.2xlarge",
+ "vcpuNum": 8
+ },
+ {
+ "_defaultOrder": 42,
+ "_isFastLaunch": false,
+ "category": "Memory Optimized",
+ "gpuNum": 0,
+ "hideHardwareSpecs": false,
+ "memoryGiB": 128,
+ "name": "ml.r5.4xlarge",
+ "vcpuNum": 16
+ },
+ {
+ "_defaultOrder": 43,
+ "_isFastLaunch": false,
+ "category": "Memory Optimized",
+ "gpuNum": 0,
+ "hideHardwareSpecs": false,
+ "memoryGiB": 256,
+ "name": "ml.r5.8xlarge",
+ "vcpuNum": 32
+ },
+ {
+ "_defaultOrder": 44,
+ "_isFastLaunch": false,
+ "category": "Memory Optimized",
+ "gpuNum": 0,
+ "hideHardwareSpecs": false,
+ "memoryGiB": 384,
+ "name": "ml.r5.12xlarge",
+ "vcpuNum": 48
+ },
+ {
+ "_defaultOrder": 45,
+ "_isFastLaunch": false,
+ "category": "Memory Optimized",
+ "gpuNum": 0,
+ "hideHardwareSpecs": false,
+ "memoryGiB": 512,
+ "name": "ml.r5.16xlarge",
+ "vcpuNum": 64
+ },
+ {
+ "_defaultOrder": 46,
+ "_isFastLaunch": false,
+ "category": "Memory Optimized",
+ "gpuNum": 0,
+ "hideHardwareSpecs": false,
+ "memoryGiB": 768,
+ "name": "ml.r5.24xlarge",
+ "vcpuNum": 96
+ },
+ {
+ "_defaultOrder": 47,
+ "_isFastLaunch": false,
+ "category": "Accelerated computing",
+ "gpuNum": 1,
+ "hideHardwareSpecs": false,
+ "memoryGiB": 16,
+ "name": "ml.g5.xlarge",
+ "vcpuNum": 4
+ },
+ {
+ "_defaultOrder": 48,
+ "_isFastLaunch": false,
+ "category": "Accelerated computing",
+ "gpuNum": 1,
+ "hideHardwareSpecs": false,
+ "memoryGiB": 32,
+ "name": "ml.g5.2xlarge",
+ "vcpuNum": 8
+ },
+ {
+ "_defaultOrder": 49,
+ "_isFastLaunch": false,
+ "category": "Accelerated computing",
+ "gpuNum": 1,
+ "hideHardwareSpecs": false,
+ "memoryGiB": 64,
+ "name": "ml.g5.4xlarge",
+ "vcpuNum": 16
+ },
+ {
+ "_defaultOrder": 50,
+ "_isFastLaunch": false,
+ "category": "Accelerated computing",
+ "gpuNum": 1,
+ "hideHardwareSpecs": false,
+ "memoryGiB": 128,
+ "name": "ml.g5.8xlarge",
+ "vcpuNum": 32
+ },
+ {
+ "_defaultOrder": 51,
+ "_isFastLaunch": false,
+ "category": "Accelerated computing",
+ "gpuNum": 1,
+ "hideHardwareSpecs": false,
+ "memoryGiB": 256,
+ "name": "ml.g5.16xlarge",
+ "vcpuNum": 64
+ },
+ {
+ "_defaultOrder": 52,
+ "_isFastLaunch": false,
+ "category": "Accelerated computing",
+ "gpuNum": 4,
+ "hideHardwareSpecs": false,
+ "memoryGiB": 192,
+ "name": "ml.g5.12xlarge",
+ "vcpuNum": 48
+ },
+ {
+ "_defaultOrder": 53,
+ "_isFastLaunch": false,
+ "category": "Accelerated computing",
+ "gpuNum": 4,
+ "hideHardwareSpecs": false,
+ "memoryGiB": 384,
+ "name": "ml.g5.24xlarge",
+ "vcpuNum": 96
+ },
+ {
+ "_defaultOrder": 54,
+ "_isFastLaunch": false,
+ "category": "Accelerated computing",
+ "gpuNum": 8,
+ "hideHardwareSpecs": false,
+ "memoryGiB": 768,
+ "name": "ml.g5.48xlarge",
+ "vcpuNum": 192
+ },
+ {
+ "_defaultOrder": 55,
+ "_isFastLaunch": false,
+ "category": "Accelerated computing",
+ "gpuNum": 8,
+ "hideHardwareSpecs": false,
+ "memoryGiB": 1152,
+ "name": "ml.p4d.24xlarge",
+ "vcpuNum": 96
+ },
+ {
+ "_defaultOrder": 56,
+ "_isFastLaunch": false,
+ "category": "Accelerated computing",
+ "gpuNum": 8,
+ "hideHardwareSpecs": false,
+ "memoryGiB": 1152,
+ "name": "ml.p4de.24xlarge",
+ "vcpuNum": 96
+ }
+ ],
+ "instance_type": "ml.t3.medium",
+ "kernelspec": {
+ "display_name": "Python 3 (ipykernel)",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.9.0"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 5
+}
diff --git a/libs/experimental/langchain_experimental/comprehend_moderation/__init__.py b/libs/experimental/langchain_experimental/comprehend_moderation/__init__.py
new file mode 100644
index 00000000000..5e4a2686317
--- /dev/null
+++ b/libs/experimental/langchain_experimental/comprehend_moderation/__init__.py
@@ -0,0 +1,25 @@
+from langchain_experimental.comprehend_moderation.amazon_comprehend_moderation import (
+ AmazonComprehendModerationChain,
+)
+from langchain_experimental.comprehend_moderation.base_moderation import BaseModeration
+from langchain_experimental.comprehend_moderation.base_moderation_callbacks import (
+ BaseModerationCallbackHandler,
+)
+from langchain_experimental.comprehend_moderation.base_moderation_enums import (
+ BaseModerationActions,
+ BaseModerationFilters,
+)
+from langchain_experimental.comprehend_moderation.intent import ComprehendIntent
+from langchain_experimental.comprehend_moderation.pii import ComprehendPII
+from langchain_experimental.comprehend_moderation.toxicity import ComprehendToxicity
+
+__all__ = [
+ "BaseModeration",
+ "BaseModerationActions",
+ "BaseModerationFilters",
+ "ComprehendPII",
+ "ComprehendIntent",
+ "ComprehendToxicity",
+ "BaseModerationCallbackHandler",
+ "AmazonComprehendModerationChain",
+]
diff --git a/libs/experimental/langchain_experimental/comprehend_moderation/amazon_comprehend_moderation.py b/libs/experimental/langchain_experimental/comprehend_moderation/amazon_comprehend_moderation.py
new file mode 100644
index 00000000000..d00520e6270
--- /dev/null
+++ b/libs/experimental/langchain_experimental/comprehend_moderation/amazon_comprehend_moderation.py
@@ -0,0 +1,184 @@
+from typing import Any, Dict, List, Optional
+
+from langchain.callbacks.manager import CallbackManagerForChainRun
+from langchain.chains.base import Chain
+
+from langchain_experimental.comprehend_moderation.base_moderation import (
+ BaseModeration,
+)
+from langchain_experimental.comprehend_moderation.base_moderation_callbacks import (
+ BaseModerationCallbackHandler,
+)
+from langchain_experimental.pydantic_v1 import root_validator
+
+
+class AmazonComprehendModerationChain(Chain):
+ """A subclass of Chain, designed to apply moderation to LLMs."""
+
+ output_key: str = "output" #: :meta private:
+ """Key used to fetch/store the output in data containers. Defaults to `output`"""
+
+ input_key: str = "input" #: :meta private:
+ """Key used to fetch/store the input in data containers. Defaults to `input`"""
+
+ moderation_config: Optional[Dict[str, Any]] = None
+ """Configuration settings for moderation"""
+
+ client: Optional[Any]
+ """boto3 client object for connection to Amazon Comprehend"""
+
+ region_name: Optional[str] = None
+ """The aws region e.g., `us-west-2`. Fallsback to AWS_DEFAULT_REGION env variable
+ or region specified in ~/.aws/config in case it is not provided here.
+ """
+
+ credentials_profile_name: Optional[str] = None
+ """The name of the profile in the ~/.aws/credentials or ~/.aws/config files, which
+ has either access keys or role information specified.
+ If not specified, the default credential profile or, if on an EC2 instance,
+ credentials from IMDS will be used.
+ See: https://boto3.amazonaws.com/v1/documentation/api/latest/guide/credentials.html
+ """
+
+ moderation_callback: Optional[BaseModerationCallbackHandler] = None
+ """Callback handler for moderation, this is different
+ from regular callbacks which can be used in addition to this."""
+
+ unique_id: Optional[str] = None
+ """A unique id that can be used to identify or group a user or session"""
+
+ @root_validator(pre=True)
+ def create_client(cls, values: Dict[str, Any]) -> Dict[str, Any]:
+ """
+ Creates an Amazon Comprehend client
+
+ Args:
+ values (Dict[str, Any]): A dictionary containing configuration values.
+
+ Returns:
+ Dict[str, Any]: A dictionary with the updated configuration values,
+ including the Amazon Comprehend client.
+
+ Raises:
+ ModuleNotFoundError: If the 'boto3' package is not installed.
+ ValueError: If there is an issue importing 'boto3' or loading
+ AWS credentials.
+
+ Example:
+ .. code-block:: python
+
+ config = {
+ "credentials_profile_name": "my-profile",
+ "region_name": "us-west-2"
+ }
+ updated_config = create_client(config)
+ comprehend_client = updated_config["client"]
+ """
+
+ if values.get("client") is not None:
+ return values
+ try:
+ import boto3
+
+ if values.get("credentials_profile_name"):
+ session = boto3.Session(profile_name=values["credentials_profile_name"])
+ else:
+ # use default credentials
+ session = boto3.Session()
+
+ client_params = {}
+ if values.get("region_name"):
+ client_params["region_name"] = values["region_name"]
+
+ values["client"] = session.client("comprehend", **client_params)
+
+ return values
+ except ImportError:
+ raise ModuleNotFoundError(
+ "Could not import boto3 python package. "
+ "Please install it with `pip install boto3`."
+ )
+ except Exception as e:
+ raise ValueError(
+ "Could not load credentials to authenticate with AWS client. "
+ "Please check that credentials in the specified "
+ "profile name are valid."
+ ) from e
+
+ @property
+ def output_keys(self) -> List[str]:
+ """
+ Returns a list of output keys.
+
+ This method defines the output keys that will be used to access the output
+ values produced by the chain or function. It ensures that the specified keys
+ are available to access the outputs.
+
+ Returns:
+ List[str]: A list of output keys.
+
+ Note:
+ This method is considered private and may not be intended for direct
+ external use.
+
+ """
+ return [self.output_key]
+
+ @property
+ def input_keys(self) -> List[str]:
+ """
+ Returns a list of input keys expected by the prompt.
+
+ This method defines the input keys that the prompt expects in order to perform
+ its processing. It ensures that the specified keys are available for providing
+ input to the prompt.
+
+ Returns:
+ List[str]: A list of input keys.
+
+ Note:
+ This method is considered private and may not be intended for direct
+ external use.
+ """
+ return [self.input_key]
+
+ def _call(
+ self,
+ inputs: Dict[str, Any],
+ run_manager: Optional[CallbackManagerForChainRun] = None,
+ ) -> Dict[str, str]:
+ """
+ Executes the moderation process on the input text and returns the processed
+ output.
+
+ This internal method performs the moderation process on the input text. It
+ converts the input prompt value to plain text, applies the specified filters,
+ and then converts the filtered output back to a suitable prompt value object.
+ Additionally, it provides the option to log information about the run using
+ the provided `run_manager`.
+
+ Args:
+ inputs: A dictionary containing input values
+ run_manager: A run manager to handle run-related events. Default is None
+
+ Returns:
+ Dict[str, str]: A dictionary containing the processed output of the
+ moderation process.
+
+ Raises:
+ ValueError: If there is an error during the moderation process
+ """
+
+ if run_manager:
+ run_manager.on_text("Running AmazonComprehendModerationChain...\n")
+
+ moderation = BaseModeration(
+ client=self.client,
+ config=self.moderation_config,
+ moderation_callback=self.moderation_callback,
+ unique_id=self.unique_id,
+ run_manager=run_manager,
+ )
+ response = moderation.moderate(prompt=inputs[self.input_keys[0]])
+
+ return {self.output_key: response}
diff --git a/libs/experimental/langchain_experimental/comprehend_moderation/base_moderation.py b/libs/experimental/langchain_experimental/comprehend_moderation/base_moderation.py
new file mode 100644
index 00000000000..c639112b959
--- /dev/null
+++ b/libs/experimental/langchain_experimental/comprehend_moderation/base_moderation.py
@@ -0,0 +1,176 @@
+import uuid
+from typing import Any, Callable, Dict, Optional
+
+from langchain.callbacks.manager import CallbackManagerForChainRun
+from langchain.prompts.base import StringPromptValue
+from langchain.prompts.chat import ChatPromptValue
+from langchain.schema import AIMessage, HumanMessage
+
+from langchain_experimental.comprehend_moderation.intent import ComprehendIntent
+from langchain_experimental.comprehend_moderation.pii import ComprehendPII
+from langchain_experimental.comprehend_moderation.toxicity import ComprehendToxicity
+
+
+class BaseModeration:
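+ """Orchestrates the PII, toxicity, and intent checks that
+ AmazonComprehendModerationChain runs against Amazon Comprehend, either in
+ the default order or in the order given by the moderation config."""
+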
+ def __init__(
+ self,
+ client: Any,
+ config: Optional[Dict[str, Any]] = None,
+ moderation_callback: Optional[Any] = None,
+ unique_id: Optional[str] = None,
+ run_manager: Optional[CallbackManagerForChainRun] = None,
+ ):
+ self.client = client
+ self.config = config
+ self.moderation_callback = moderation_callback
+ self.unique_id = unique_id
+ self.chat_message_index = 0
+ self.run_manager = run_manager
+ self.chain_id = str(uuid.uuid4())
+
+ def _convert_prompt_to_text(self, prompt: Any) -> str:
+ input_text = str()
+
+ if isinstance(prompt, StringPromptValue):
+ input_text = prompt.text
+ elif isinstance(prompt, str):
+ input_text = prompt
+ elif isinstance(prompt, ChatPromptValue):
+ """
+ We only check the last message in the message chain of a
+ ChatPromptTemplate. The typical chronology is
+ SystemMessage > HumanMessage > AIMessage and so on; assuming the chain
+ is invoked with every chat, we only need to check the last message,
+ on the assumption that all previous messages have already been checked.
+ Only HumanMessage and AIMessage are checked. We could perhaps loop
+ through and take advantage of the additional_kwargs property of the
+ HumanMessage and AIMessage schemas to mark messages that have been
+ moderated. However, that would mean this class could generate multiple
+ text chunks and the moderate() logic would need to be updated. It would
+ also add some complexity in re-constructing the prompt while keeping
+ the messages in sequence.
+ """
+ message = prompt.messages[-1]
+ self.chat_message_index = len(prompt.messages) - 1
+ if isinstance(message, HumanMessage):
+ input_text = message.content
+
+ if isinstance(message, AIMessage):
+ input_text = message.content
+ else:
+ raise ValueError(
+ f"Invalid input type {type(input)}. "
+ "Must be a PromptValue, str, or list of BaseMessages."
+ )
+ return input_text
+
+ def _convert_text_to_prompt(self, prompt: Any, text: str) -> Any:
+ if isinstance(prompt, StringPromptValue):
+ return StringPromptValue(text=text)
+ elif isinstance(prompt, str):
+ return text
+ elif isinstance(prompt, ChatPromptValue):
+ messages = prompt.messages
+ message = messages[self.chat_message_index]
+
+ if isinstance(message, HumanMessage):
+ messages[self.chat_message_index] = HumanMessage(
+ content=text,
+ example=message.example,
+ additional_kwargs=message.additional_kwargs,
+ )
+ if isinstance(message, AIMessage):
+ messages[self.chat_message_index] = AIMessage(
+ content=text,
+ example=message.example,
+ additional_kwargs=message.additional_kwargs,
+ )
+ return ChatPromptValue(messages=messages)
+ else:
+ raise ValueError(
+ f"Invalid input type {type(input)}. "
+ "Must be a PromptValue, str, or list of BaseMessages."
+ )
+
+ def _moderation_class(self, moderation_class: Any) -> Callable:
+ return moderation_class(
+ client=self.client,
+ callback=self.moderation_callback,
+ unique_id=self.unique_id,
+ chain_id=self.chain_id,
+ ).validate
+
+ def _log_message_for_verbose(self, message: str) -> None:
+ if self.run_manager:
+ self.run_manager.on_text(message)
+
+ def moderate(self, prompt: Any) -> str:
+ from langchain_experimental.comprehend_moderation.base_moderation_exceptions import ( # noqa: E501
+ ModerationIntentionError,
+ ModerationPiiError,
+ ModerationToxicityError,
+ )
+
+ try:
+ # convert prompt to text
+ input_text = self._convert_prompt_to_text(prompt=prompt)
+ output_text = str()
+ # perform moderation
+ if self.config is None:
+ # In absence of config Action will default to STOP only
+ self._log_message_for_verbose("Running pii validation...\n")
+ pii_validate = self._moderation_class(moderation_class=ComprehendPII)
+ output_text = pii_validate(prompt_value=input_text)
+
+ self._log_message_for_verbose("Running toxicity validation...\n")
+ toxicity_validate = self._moderation_class(
+ moderation_class=ComprehendToxicity
+ )
+ output_text = toxicity_validate(prompt_value=output_text)
+
+ self._log_message_for_verbose("Running intent validation...\n")
+ intent_validate = self._moderation_class(
+ moderation_class=ComprehendIntent
+ )
+ output_text = intent_validate(prompt_value=output_text)
+ else:
+ filter_functions = {
+ "pii": ComprehendPII,
+ "toxicity": ComprehendToxicity,
+ "intent": ComprehendIntent,
+ }
+ filters = self.config["filters"]
+ for _filter in filters:
+ filter_name = f"{_filter}"
+ if filter_name in filter_functions:
+ self._log_message_for_verbose(
+ f"Running {filter_name} Validation...\n"
+ )
+ validation_fn = self._moderation_class(
+ moderation_class=filter_functions[filter_name]
+ )
+ input_text = input_text if not output_text else output_text
+ output_text = validation_fn(
+ prompt_value=input_text,
+ config=self.config[filter_name]
+ if filter_name in self.config
+ else None,
+ )
+ # convert text to prompt and return
+ return self._convert_text_to_prompt(prompt=prompt, text=output_text)
+
+ except ModerationPiiError as e:
+ self._log_message_for_verbose(f"Found PII content..stopping..\n{str(e)}\n")
+ raise e
+ except ModerationToxicityError as e:
+ self._log_message_for_verbose(
+ f"Found Toxic content..stopping..\n{str(e)}\n"
+ )
+ raise e
+ except ModerationIntentionError as e:
+ self._log_message_for_verbose(
+ f"Found Harmful intention..stopping..\n{str(e)}\n"
+ )
+ raise e
+ except Exception as e:
+ raise e
diff --git a/libs/experimental/langchain_experimental/comprehend_moderation/base_moderation_callbacks.py b/libs/experimental/langchain_experimental/comprehend_moderation/base_moderation_callbacks.py
new file mode 100644
index 00000000000..d7fcd76a106
--- /dev/null
+++ b/libs/experimental/langchain_experimental/comprehend_moderation/base_moderation_callbacks.py
@@ -0,0 +1,64 @@
+from typing import Any, Callable, Dict
+
+
+class BaseModerationCallbackHandler:
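+ """Asynchronous callback handler for Amazon Comprehend moderation.
+
+ Subclass this and override at least one of ``on_after_pii()``,
+ ``on_after_toxicity()``, or ``on_after_intent()``. Each override must be an
+ ``async`` function and receives the moderation beacon dictionary and the
+ ``unique_id`` passed to the chain."""
+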
+ def __init__(self) -> None:
+ if (
+ self._is_method_unchanged(
+ BaseModerationCallbackHandler.on_after_pii, self.on_after_pii
+ )
+ and self._is_method_unchanged(
+ BaseModerationCallbackHandler.on_after_toxicity, self.on_after_toxicity
+ )
+ and self._is_method_unchanged(
+ BaseModerationCallbackHandler.on_after_intent, self.on_after_intent
+ )
+ ):
+ raise NotImplementedError(
+ "Subclasses must override at least one of on_after_pii(), "
+ "on_after_toxicity(), or on_after_intent() functions."
+ )
+
+ def _is_method_unchanged(
+ self, base_method: Callable, derived_method: Callable
+ ) -> bool:
+ return base_method.__qualname__ == derived_method.__qualname__
+
+ async def on_after_pii(
+ self, moderation_beacon: Dict[str, Any], unique_id: str, **kwargs: Any
+ ) -> None:
+ """Run after PII validation is complete."""
+ raise NotImplementedError("Subclasses should implement this async method.")
+
+ async def on_after_toxicity(
+ self, moderation_beacon: Dict[str, Any], unique_id: str, **kwargs: Any
+ ) -> None:
+ """Run after Toxicity validation is complete."""
+ raise NotImplementedError("Subclasses should implement this async method.")
+
+ async def on_after_intent(
+ self, moderation_beacon: Dict[str, Any], unique_id: str, **kwargs: Any
+ ) -> None:
+ """Run after Toxicity validation is complete."""
+ raise NotImplementedError("Subclasses should implement this async method.")
+
+ @property
+ def pii_callback(self) -> bool:
+ return (
+ self.on_after_pii.__func__ # type: ignore
+ is not BaseModerationCallbackHandler.on_after_pii
+ )
+
+ @property
+ def toxicity_callback(self) -> bool:
+ return (
+ self.on_after_toxicity.__func__ # type: ignore
+ is not BaseModerationCallbackHandler.on_after_toxicity
+ )
+
+ @property
+ def intent_callback(self) -> bool:
+ return (
+ self.on_after_intent.__func__ # type: ignore
+ is not BaseModerationCallbackHandler.on_after_intent
+ )
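Because the constructor raises unless at least one hook is overridden, a subclass only implements the callbacks it needs; the corresponding `*_callback` property then flips to `True` and the matching validator starts publishing its beacon. A minimal sketch that only watches PII results:

```python
from typing import Any, Dict

from langchain_experimental.comprehend_moderation.base_moderation_callbacks import (
    BaseModerationCallbackHandler,
)


class PiiLoggingHandler(BaseModerationCallbackHandler):
    """Overrides only on_after_pii, so pii_callback is True and the others stay False."""

    async def on_after_pii(
        self, moderation_beacon: Dict[str, Any], unique_id: str, **kwargs: Any
    ) -> None:
        # The beacon carries the chain id, the input text, the raw Comprehend
        # response, and a LABELS_FOUND / LABELS_NOT_FOUND status.
        print(unique_id, moderation_beacon["moderation_status"])
```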
diff --git a/libs/experimental/langchain_experimental/comprehend_moderation/base_moderation_enums.py b/libs/experimental/langchain_experimental/comprehend_moderation/base_moderation_enums.py
new file mode 100644
index 00000000000..aec629ebcc8
--- /dev/null
+++ b/libs/experimental/langchain_experimental/comprehend_moderation/base_moderation_enums.py
@@ -0,0 +1,12 @@
+from enum import Enum
+
+
+class BaseModerationActions(Enum):
+ STOP = 1
+ ALLOW = 2
+
+
+class BaseModerationFilters(str, Enum):
+ PII = "pii"
+ TOXICITY = "toxicity"
+ INTENT = "intent"
diff --git a/libs/experimental/langchain_experimental/comprehend_moderation/base_moderation_exceptions.py b/libs/experimental/langchain_experimental/comprehend_moderation/base_moderation_exceptions.py
new file mode 100644
index 00000000000..74b3971df04
--- /dev/null
+++ b/libs/experimental/langchain_experimental/comprehend_moderation/base_moderation_exceptions.py
@@ -0,0 +1,43 @@
+class ModerationPiiError(Exception):
+ """Exception raised if PII entities are detected.
+
+ Attributes:
+ message -- explanation of the error
+ """
+
+ def __init__(
+ self, message: str = "The prompt contains PII entities and cannot be processed"
+ ):
+ self.message = message
+ super().__init__(self.message)
+
+
+class ModerationToxicityError(Exception):
+ """Exception raised if Toxic entities are detected.
+
+ Attributes:
+ message -- explanation of the error
+ """
+
+ def __init__(
+ self, message: str = "The prompt contains toxic content and cannot be processed"
+ ):
+ self.message = message
+ super().__init__(self.message)
+
+
+class ModerationIntentionError(Exception):
+ """Exception raised if Intention entities are detected.
+
+ Attributes:
+ message -- explanation of the error
+ """
+
+ def __init__(
+ self,
+ message: str = (
+ "The prompt indicates an un-desired intent and " "cannot be processed"
+ ),
+ ):
+ self.message = message
+ super().__init__(self.message)
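Each validator raises one of these exceptions when its STOP action fires, and the moderation chain re-raises them, so callers can branch on the failure type. A hedged sketch, assuming `chain` is an LLM chain that has the moderation chain wired in:

```python
from langchain_experimental.comprehend_moderation.base_moderation_exceptions import (
    ModerationIntentionError,
    ModerationPiiError,
    ModerationToxicityError,
)

try:
    response = chain.run(question="A typical question goes here")  # `chain` is assumed
except ModerationPiiError:
    print("Blocked: the text contained PII.")
except ModerationToxicityError:
    print("Blocked: the text contained toxic content.")
except ModerationIntentionError:
    print("Blocked: the text indicated an undesired intent.")
```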
diff --git a/libs/experimental/langchain_experimental/comprehend_moderation/intent.py b/libs/experimental/langchain_experimental/comprehend_moderation/intent.py
new file mode 100644
index 00000000000..761c0728689
--- /dev/null
+++ b/libs/experimental/langchain_experimental/comprehend_moderation/intent.py
@@ -0,0 +1,101 @@
+import asyncio
+import warnings
+from typing import Any, Dict, Optional
+
+from langchain_experimental.comprehend_moderation.base_moderation_exceptions import (
+ ModerationIntentionError,
+)
+
+
+class ComprehendIntent:
+ def __init__(
+ self,
+ client: Any,
+ callback: Optional[Any] = None,
+ unique_id: Optional[str] = None,
+ chain_id: Optional[str] = None,
+ ) -> None:
+ self.client = client
+ self.moderation_beacon = {
+ "moderation_chain_id": chain_id,
+ "moderation_type": "Intent",
+ "moderation_status": "LABELS_NOT_FOUND",
+ }
+ self.callback = callback
+ self.unique_id = unique_id
+
+ def _get_arn(self) -> str:
+ region_name = self.client.meta.region_name
+ service = "comprehend"
+ intent_endpoint = "document-classifier-endpoint/prompt-intent"
+ return f"arn:aws:{service}:{region_name}:aws:{intent_endpoint}"
+
+ def validate(
+ self, prompt_value: str, config: Optional[Dict[str, Any]] = None
+ ) -> str:
+ """
+ Check and validate the intent of the given prompt text.
+
+ Args:
+ prompt_value (str): The input text to be checked for undesired intent
+ config (Dict[str, Any]): Configuration settings for intent checks
+
+ Raises:
+ ModerationIntentionError: If undesired intent is found in the
+ prompt text with a score above the specified threshold
+ for the UNDESIRED_PROMPT label.
+
+ Returns:
+ str: The input prompt_value.
+
+ Note:
+ This function checks the intent of the provided prompt text using
+ Comprehend's classify_document API and raises an error if unintended
+ intent is detected with a score above the specified threshold.
+
+ """
+ from langchain_experimental.comprehend_moderation.base_moderation_enums import (
+ BaseModerationActions,
+ )
+
+ threshold = config.get("threshold", 0.5) if config else 0.5
+ action = (
+ config.get("action", BaseModerationActions.STOP)
+ if config
+ else BaseModerationActions.STOP
+ )
+ intent_found = False
+
+ if action == BaseModerationActions.ALLOW:
+ warnings.warn(
+ "You have allowed content with Harmful content."
+ "Defaulting to STOP action..."
+ )
+ action = BaseModerationActions.STOP
+
+ endpoint_arn = self._get_arn()
+ response = self.client.classify_document(
+ Text=prompt_value, EndpointArn=endpoint_arn
+ )
+
+ if self.callback and self.callback.intent_callback:
+ self.moderation_beacon["moderation_input"] = prompt_value
+ self.moderation_beacon["moderation_output"] = response
+
+ for class_result in response["Classes"]:
+ if (
+ class_result["Score"] >= threshold
+ and class_result["Name"] == "UNDESIRED_PROMPT"
+ ):
+ intent_found = True
+ break
+
+ if self.callback and self.callback.intent_callback:
+ if intent_found:
+ self.moderation_beacon["moderation_status"] = "LABELS_FOUND"
+ asyncio.create_task(
+ self.callback.on_after_intent(self.moderation_beacon, self.unique_id)
+ )
+ if intent_found:
+ raise ModerationIntentionError
+ return prompt_value
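The class can also be exercised directly; a sketch assuming a boto3 Comprehend client in a region where the managed prompt-intent classifier endpoint is reachable, with an illustrative threshold:

```python
import boto3

from langchain_experimental.comprehend_moderation.intent import ComprehendIntent

comprehend = boto3.client("comprehend", region_name="us-east-1")
intent_check = ComprehendIntent(client=comprehend)

# Returns the text unchanged, or raises ModerationIntentionError if the
# UNDESIRED_PROMPT class scores at or above the configured threshold.
safe_text = intent_check.validate(
    prompt_value="What is the capital of France?",
    config={"threshold": 0.8},
)
```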
diff --git a/libs/experimental/langchain_experimental/comprehend_moderation/pii.py b/libs/experimental/langchain_experimental/comprehend_moderation/pii.py
new file mode 100644
index 00000000000..2c82b7a4004
--- /dev/null
+++ b/libs/experimental/langchain_experimental/comprehend_moderation/pii.py
@@ -0,0 +1,173 @@
+import asyncio
+from typing import Any, Dict, Optional
+
+from langchain_experimental.comprehend_moderation.base_moderation_exceptions import (
+ ModerationPiiError,
+)
+
+
+class ComprehendPII:
+ def __init__(
+ self,
+ client: Any,
+ callback: Optional[Any] = None,
+ unique_id: Optional[str] = None,
+ chain_id: Optional[str] = None,
+ ) -> None:
+ self.client = client
+ self.moderation_beacon = {
+ "moderation_chain_id": chain_id,
+ "moderation_type": "PII",
+ "moderation_status": "LABELS_NOT_FOUND",
+ }
+ self.callback = callback
+ self.unique_id = unique_id
+
+ def validate(
+ self, prompt_value: str, config: Optional[Dict[str, Any]] = None
+ ) -> str:
+ from langchain_experimental.comprehend_moderation.base_moderation_enums import (
+ BaseModerationActions,
+ )
+
+ if config:
+ action = config.get("action", BaseModerationActions.STOP)
+ if action not in [BaseModerationActions.STOP, BaseModerationActions.ALLOW]:
+ raise ValueError("Action can either be stop or allow")
+
+ return (
+ self._contains_pii(prompt_value=prompt_value, config=config)
+ if action == BaseModerationActions.STOP
+ else self._detect_pii(prompt_value=prompt_value, config=config)
+ )
+ else:
+ return self._contains_pii(prompt_value=prompt_value)
+
+ def _contains_pii(
+ self, prompt_value: str, config: Optional[Dict[str, Any]] = None
+ ) -> str:
+ """
+ Checks for Personally Identifiable Information (PII) labels above a
+ specified threshold.
+
+ Args:
+ prompt_value (str): The input text to be checked for PII labels.
+ config (Dict[str, Any]): Configuration for PII check and actions.
+
+ Returns:
+ str: the original prompt
+
+ Note:
+ - The provided client should be initialized with valid AWS credentials.
+ """
+ pii_identified = self.client.contains_pii_entities(
+ Text=prompt_value, LanguageCode="en"
+ )
+
+ if self.callback and self.callback.pii_callback:
+ self.moderation_beacon["moderation_input"] = prompt_value
+ self.moderation_beacon["moderation_output"] = pii_identified
+
+ threshold = config.get("threshold", 0.5) if config else 0.5
+ pii_labels = config.get("labels", []) if config else []
+ pii_found = False
+ for entity in pii_identified["Labels"]:
+ if (entity["Score"] >= threshold and entity["Name"] in pii_labels) or (
+ entity["Score"] >= threshold and not pii_labels
+ ):
+ pii_found = True
+ break
+
+ if self.callback and self.callback.pii_callback:
+ if pii_found:
+ self.moderation_beacon["moderation_status"] = "LABELS_FOUND"
+ asyncio.create_task(
+ self.callback.on_after_pii(self.moderation_beacon, self.unique_id)
+ )
+ if pii_found:
+ raise ModerationPiiError
+ return prompt_value
+
+ def _detect_pii(self, prompt_value: str, config: Optional[Dict[str, Any]]) -> str:
+ """
+ Detects and handles Personally Identifiable Information (PII) entities in the
+ given prompt text using Amazon Comprehend's detect_pii_entities API. The
+ function provides options to redact or stop processing based on the identified
+ PII entities and a provided configuration.
+
+ Args:
+ prompt_value (str): The input text to be checked for PII entities.
+ config (Dict[str, Any]): A configuration specifying how to handle
+ PII entities.
+
+ Returns:
+ str: The prompt text with PII entities redacted according to
+ the configuration.
+
+ Raises:
+ ModerationPiiError: If PII entities are detected and no
+ redaction configuration is provided.
+
+ Note:
+ - If PII is not found in the prompt, the original prompt is returned.
+ - The client should be initialized with valid AWS credentials.
+ """
+ pii_identified = self.client.detect_pii_entities(
+ Text=prompt_value, LanguageCode="en"
+ )
+
+ if self.callback and self.callback.pii_callback:
+ self.moderation_beacon["moderation_input"] = prompt_value
+ self.moderation_beacon["moderation_output"] = pii_identified
+
+ if (pii_identified["Entities"]) == []:
+ if self.callback and self.callback.pii_callback:
+ asyncio.create_task(
+ self.callback.on_after_pii(self.moderation_beacon, self.unique_id)
+ )
+ return prompt_value
+
+ pii_found = False
+ if not config and pii_identified["Entities"]:
+ for entity in pii_identified["Entities"]:
+ if entity["Score"] >= 0.5:
+ pii_found = True
+ break
+
+ if self.callback and self.callback.pii_callback:
+ if pii_found:
+ self.moderation_beacon["moderation_status"] = "LABELS_FOUND"
+ asyncio.create_task(
+ self.callback.on_after_pii(self.moderation_beacon, self.unique_id)
+ )
+ if pii_found:
+ raise ModerationPiiError
+ else:
+ threshold = config.get("threshold", 0.5) # type: ignore
+ pii_labels = config.get("labels", []) # type: ignore
+ mask_marker = config.get("mask_character", "*") # type: ignore
+ pii_found = False
+
+ for entity in pii_identified["Entities"]:
+ if (
+ pii_labels
+ and entity["Type"] in pii_labels
+ and entity["Score"] >= threshold
+ ) or (not pii_labels and entity["Score"] >= threshold):
+ pii_found = True
+ char_offset_begin = entity["BeginOffset"]
+ char_offset_end = entity["EndOffset"]
+ prompt_value = (
+ prompt_value[:char_offset_begin]
+ + mask_marker * (char_offset_end - char_offset_begin)
+ + prompt_value[char_offset_end:]
+ )
+
+ if self.callback and self.callback.pii_callback:
+ if pii_found:
+ self.moderation_beacon["moderation_status"] = "LABELS_FOUND"
+ asyncio.create_task(
+ self.callback.on_after_pii(self.moderation_beacon, self.unique_id)
+ )
+
+ return prompt_value
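With the ALLOW action, `validate()` routes to `_detect_pii`, which redacts matched spans instead of raising. A sketch of a redaction-oriented config, with illustrative entity types:

```python
from langchain_experimental.comprehend_moderation.base_moderation_enums import (
    BaseModerationActions,
)

pii_config = {
    "action": BaseModerationActions.ALLOW,  # redact instead of raising ModerationPiiError
    "threshold": 0.5,                       # minimum Comprehend confidence score
    "labels": ["SSN", "EMAIL"],             # entity Types to redact; [] redacts all types
    "mask_character": "X",                  # defaults to "*" when omitted
}
# e.g. ComprehendPII(client=comprehend).validate(prompt_value=text, config=pii_config)
```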
diff --git a/libs/experimental/langchain_experimental/comprehend_moderation/toxicity.py b/libs/experimental/langchain_experimental/comprehend_moderation/toxicity.py
new file mode 100644
index 00000000000..b66320ec552
--- /dev/null
+++ b/libs/experimental/langchain_experimental/comprehend_moderation/toxicity.py
@@ -0,0 +1,209 @@
+import asyncio
+import importlib
+import warnings
+from typing import Any, Dict, List, Optional
+
+from langchain_experimental.comprehend_moderation.base_moderation_exceptions import (
+ ModerationToxicityError,
+)
+
+
+class ComprehendToxicity:
+ def __init__(
+ self,
+ client: Any,
+ callback: Optional[Any] = None,
+ unique_id: Optional[str] = None,
+ chain_id: Optional[str] = None,
+ ) -> None:
+ self.client = client
+ self.moderation_beacon = {
+ "moderation_chain_id": chain_id,
+ "moderation_type": "Toxicity",
+ "moderation_status": "LABELS_NOT_FOUND",
+ }
+ self.callback = callback
+ self.unique_id = unique_id
+
+ def _toxicity_init_validate(self, max_size: int) -> Any:
+ """
+ Validate and initialize toxicity processing configuration.
+
+ Args:
+ max_size (int): Maximum sentence size defined in the configuration object.
+
+ Raises:
+ Exception: If the maximum sentence size exceeds the 5KB limit.
+
+ Note:
+ This function ensures that the NLTK punkt tokenizer is downloaded if not
+ already present.
+
+ Returns:
+ The imported nltk module.
+ """
+ if max_size > 1024 * 5:
+ raise Exception("The sentence length should not exceed 5KB.")
+ try:
+ nltk = importlib.import_module("nltk")
+ nltk.data.find("tokenizers/punkt")
+ return nltk
+ except ImportError:
+ raise ModuleNotFoundError(
+ "Could not import nltk python package. "
+ "Please install it with `pip install nltk`."
+ )
+ except LookupError:
+ nltk.download("punkt")
+ return nltk
+
+ def _split_paragraph(
+ self, prompt_value: str, max_size: int = 1024 * 4
+ ) -> List[List[str]]:
+ """
+ Split a paragraph into chunks of sentences, respecting the maximum size limit.
+
+ Args:
+ prompt_value (str): The input paragraph to be split into chunks.
+ max_size (int, optional): The maximum size limit in bytes for each
+ chunk. Defaults to 4096 (4KB).
+
+ Returns:
+ List[List[str]]: A list of chunks, where each chunk is a list of sentences
+
+ Note:
+ This function validates the maximum sentence size based on service limits
+ using the 'toxicity_init_validate' function. It uses the NLTK sentence
+ tokenizer to split the paragraph into sentences.
+
+ """
+
+ # validate max. sentence size based on Service limits
+ nltk = self._toxicity_init_validate(max_size)
+
+ sentences = nltk.sent_tokenize(prompt_value)
+
+ chunks = []
+ current_chunk = [] # type: ignore
+ current_size = 0
+
+ for sentence in sentences:
+ sentence_size = len(sentence.encode("utf-8"))
+
+ # If adding a new sentence exceeds max_size or
+ # current_chunk has 10 sentences, start a new chunk
+ if (current_size + sentence_size > max_size) or (len(current_chunk) >= 10):
+ if current_chunk: # Avoid appending empty chunks
+ chunks.append(current_chunk)
+ current_chunk = []
+ current_size = 0
+
+ current_chunk.append(sentence)
+ current_size += sentence_size
+
+ # Add any remaining sentences
+ if current_chunk:
+ chunks.append(current_chunk)
+
+ return chunks
+
+ def validate(
+ self, prompt_value: str, config: Optional[Dict[str, Any]] = None
+ ) -> str:
+ """
+ Check the toxicity of a given text prompt using AWS Comprehend service
+ and apply actions based on configuration.
+
+ Args:
+ prompt_value (str): The text content to be checked for toxicity.
+ config (Dict[str, Any]): Configuration for toxicity checks and actions.
+
+ Returns:
+ str: The original prompt_value if allowed or no toxicity found.
+
+ Raises:
+ ModerationToxicityError: If the prompt contains toxic labels and
+ cannot be processed based on the configuration.
+ """
+
+ chunks = self._split_paragraph(prompt_value=prompt_value)
+ for sentence_list in chunks:
+ segments = [{"Text": sentence} for sentence in sentence_list]
+ response = self.client.detect_toxic_content(
+ TextSegments=segments, LanguageCode="en"
+ )
+ if self.callback and self.callback.toxicity_callback:
+ self.moderation_beacon["moderation_input"] = segments # type: ignore
+ self.moderation_beacon["moderation_output"] = response
+
+ if config:
+ from langchain_experimental.comprehend_moderation.base_moderation_enums import ( # noqa: E501
+ BaseModerationActions,
+ )
+
+ toxicity_found = False
+ action = config.get("action", BaseModerationActions.STOP)
+ if action not in [
+ BaseModerationActions.STOP,
+ BaseModerationActions.ALLOW,
+ ]:
+ raise ValueError("Action can either be stop or allow")
+
+ threshold = config.get("threshold", 0.5) if config else 0.5
+ toxicity_labels = config.get("labels", []) if config else []
+
+ if action == BaseModerationActions.STOP:
+ for item in response["ResultList"]:
+ for label in item["Labels"]:
+ if (
+ label
+ and (
+ not toxicity_labels
+ or label["Name"] in toxicity_labels
+ )
+ and label["Score"] >= threshold
+ ):
+ toxicity_found = True
+ break
+
+ if action == BaseModerationActions.ALLOW:
+ if not toxicity_labels:
+ warnings.warn(
+ "You have allowed toxic content without specifying "
+ "any toxicity labels."
+ )
+ else:
+ for item in response["ResultList"]:
+ for label in item["Labels"]:
+ if (
+ label["Name"] in toxicity_labels
+ and label["Score"] >= threshold
+ ):
+ toxicity_found = True
+ break
+
+ if self.callback and self.callback.toxicity_callback:
+ if toxicity_found:
+ self.moderation_beacon["moderation_status"] = "LABELS_FOUND"
+ asyncio.create_task(
+ self.callback.on_after_toxicity(
+ self.moderation_beacon, self.unique_id
+ )
+ )
+ if toxicity_found:
+ raise ModerationToxicityError
+ else:
+ if response["ResultList"]:
+ detected_toxic_labels = list()
+ for item in response["ResultList"]:
+ detected_toxic_labels.extend(item["Labels"])
+ if any(item["Score"] >= 0.5 for item in detected_toxic_labels):
+ if self.callback and self.callback.toxicity_callback:
+ self.moderation_beacon["moderation_status"] = "LABELS_FOUND"
+ asyncio.create_task(
+ self.callback.on_after_toxicity(
+ self.moderation_beacon, self.unique_id
+ )
+ )
+ raise ModerationToxicityError
+
+ return prompt_value
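Toxicity follows the same configuration shape: an action, a score threshold, and an optional list of labels to look for (the label names come from the Comprehend toxicity API; the ones below are illustrative). A sketch:

```python
from langchain_experimental.comprehend_moderation.base_moderation_enums import (
    BaseModerationActions,
)

toxicity_config = {
    "action": BaseModerationActions.STOP,    # raise ModerationToxicityError on a match
    "threshold": 0.6,                        # minimum label score
    "labels": ["HATE_SPEECH", "PROFANITY"],  # illustrative; [] means any toxic label
}
# e.g. ComprehendToxicity(client=comprehend).validate(prompt_value=text, config=toxicity_config)
```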