feat: PromptGuard integration

This commit is contained in:
Zizhong Zhang
2023-08-18 15:00:31 -07:00
committed by GitHub
parent b58d492e05
commit 5f87267626
4 changed files with 517 additions and 0 deletions

View File

@@ -0,0 +1,214 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# PromptGuard\n",
"\n",
"[PromptGuard](https://promptguard.readthedocs.io/en/latest/) is a service that enables applications to leverage the power of language models without compromising user privacy. Designed for composability and ease of integration into existing applications and services, PromptGuard is consumable via a simple Python library as well as through LangChain. Perhaps more importantly, PromptGuard leverages the power of [confidential computing](https://en.wikipedia.org/wiki/Confidential_computing) to ensure that even the PromptGuard service itself cannot access the data it is protecting.\n",
" \n",
"\n",
"This notebook goes over how to use LangChain to interact with `PromptGuard`."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# install the promptguard and langchain packages\n",
"! pip install promptguard langchain"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Accessing the PromptGuard API requires an API key, which you can get by creating an account on [the PromptGuard website](https://promptguard.opaque.co/). Once you have an account, you can find your API key on [the API Keys page](https://promptguard.opaque.co/api-keys)."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import os\n",
"\n",
"# Set API keys\n",
"\n",
"os.environ['PROMPT_GUARD_API_KEY'] = \"<PROMPT_GUARD_API_KEY>\"\n",
"os.environ['OPENAI_API_KEY'] = \"<OPENAI_API_KEY>\""
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Use PromptGuardLLMWrapper\n",
"\n",
"Applying promptguard to your application could be as simple as wrapping your LLM using the PromptGuardLLMWrapper class by replace `llm=OpenAI()` with `llm=PromptGuardLLMWrapper(OpenAI())`."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import langchain\n",
"from langchain import LLMChain, PromptTemplate\n",
"from langchain.callbacks.stdout import StdOutCallbackHandler\n",
"from langchain.llms import OpenAI\n",
"from langchain.memory import ConversationBufferWindowMemory\n",
"\n",
"from langchain.llms import PromptGuardLLMWrapper\n",
"\n",
"langchain.verbose = True\n",
"langchain.debug = True\n",
"\n",
"prompt_template = \"\"\"\n",
"As an AI assistant, you will answer questions according to given context.\n",
"\n",
"Sensitive personal information in the question is masked for privacy.\n",
"For instance, if the original text says \"Giana is good,\" it will be changed\n",
"to \"PERSON_998 is good.\" \n",
"\n",
"Here's how to handle these changes:\n",
"* Consider these masked phrases just as placeholders, but still refer to\n",
"them in a relevant way when answering.\n",
"* It's possible that different masked terms might mean the same thing.\n",
"Stick with the given term and don't modify it.\n",
"* All masked terms follow the \"TYPE_ID\" pattern.\n",
"* Please don't invent new masked terms. For instance, if you see \"PERSON_998,\"\n",
"don't come up with \"PERSON_997\" or \"PERSON_999\" unless they're already in the question.\n",
"\n",
"Conversation History: ```{history}```\n",
"Context : ```During our recent meeting on February 23, 2023, at 10:30 AM,\n",
"John Doe provided me with his personal details. His email is johndoe@example.com\n",
"and his contact number is 650-456-7890. He lives in New York City, USA, and\n",
"belongs to the American nationality with Christian beliefs and a leaning towards\n",
"the Democratic party. He mentioned that he recently made a transaction using his\n",
"credit card 4111 1111 1111 1111 and transferred bitcoins to the wallet address\n",
"1A1zP1eP5QGefi2DMPTfTL5SLmv7DivfNa. While discussing his European travels, he noted\n",
"down his IBAN as GB29 NWBK 6016 1331 9268 19. Additionally, he provided his website\n",
"as https://johndoeportfolio.com. John also discussed some of his US-specific details.\n",
"He said his bank account number is 1234567890123456 and his drivers license is Y12345678.\n",
"His ITIN is 987-65-4321, and he recently renewed his passport, the number for which is\n",
"123456789. He emphasized not to share his SSN, which is 123-45-6789. Furthermore, he\n",
"mentioned that he accesses his work files remotely through the IP 192.168.1.1 and has\n",
"a medical license number MED-123456. ```\n",
"Question: ```{question}```\n",
"\n",
"\"\"\"\n",
"\n",
"chain = LLMChain(\n",
" prompt=PromptTemplate.from_template(prompt_template),\n",
" llm=PromptGuardLLMWrapper(llm=OpenAI()),\n",
" memory=ConversationBufferWindowMemory(k=2),\n",
" verbose=True,\n",
")\n",
"\n",
"\n",
"print(\n",
" chain.run(\n",
" {\"question\": \"\"\"Write a message to remind John to do password reset for his website to stay secure.\"\"\"},\n",
" callbacks=[StdOutCallbackHandler()],\n",
" )\n",
")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"From the output, you can see the following context from user input has sensitive data.\n",
"\n",
"``` \n",
"# Context from user input\n",
"\n",
"During our recent meeting on February 23, 2023, at 10:30 AM, John Doe provided me with his personal details. His email is johndoe@example.com and his contact number is 650-456-7890. He lives in New York City, USA, and belongs to the American nationality with Christian beliefs and a leaning towards the Democratic party. He mentioned that he recently made a transaction using his credit card 4111 1111 1111 1111 and transferred bitcoins to the wallet address 1A1zP1eP5QGefi2DMPTfTL5SLmv7DivfNa. While discussing his European travels, he noted down his IBAN as GB29 NWBK 6016 1331 9268 19. Additionally, he provided his website as https://johndoeportfolio.com. John also discussed some of his US-specific details. He said his bank account number is 1234567890123456 and his drivers license is Y12345678. His ITIN is 987-65-4321, and he recently renewed his passport, the number for which is 123456789. He emphasized not to share his SSN, which is 669-45-6789. Furthermore, he mentioned that he accesses his work files remotely through the IP 192.168.1.1 and has a medical license number MED-123456.\n",
"```\n",
"\n",
"PromptGuard will automatically detect the sensitive data and replace it with a placeholder. \n",
"\n",
"```\n",
"# Context after PromptGuard\n",
"\n",
"During our recent meeting on DATE_TIME_3, at DATE_TIME_2, PERSON_3 provided me with his personal details. His email is EMAIL_ADDRESS_1 and his contact number is PHONE_NUMBER_1. He lives in LOCATION_3, LOCATION_2, and belongs to the NRP_3 nationality with NRP_2 beliefs and a leaning towards the Democratic party. He mentioned that he recently made a transaction using his credit card CREDIT_CARD_1 and transferred bitcoins to the wallet address CRYPTO_1. While discussing his NRP_1 travels, he noted down his IBAN as IBAN_CODE_1. Additionally, he provided his website as URL_1. PERSON_2 also discussed some of his LOCATION_1-specific details. He said his bank account number is US_BANK_NUMBER_1 and his drivers license is US_DRIVER_LICENSE_2. His ITIN is US_ITIN_1, and he recently renewed his passport, the number for which is DATE_TIME_1. He emphasized not to share his SSN, which is US_SSN_1. Furthermore, he mentioned that he accesses his work files remotely through the IP IP_ADDRESS_1 and has a medical license number MED-US_DRIVER_LICENSE_1.\n",
"```\n",
"\n",
"Placeholder is used in the LLM response.\n",
"\n",
"```\n",
"# response returned by LLM\n",
"\n",
"Hey PERSON_1, just wanted to remind you to do a password reset for your website URL_1 through your email EMAIL_ADDRESS_1. It's important to stay secure online, so don't forget to do it!\n",
"```\n",
"\n",
"Response is desanitized by replacing the placeholder with the original sensitive data.\n",
"\n",
"```\n",
"# desanitized LLM response from PromptGuard\n",
"\n",
"Hey John, just wanted to remind you to do a password reset for your website https://johndoeportfolio.com through your email johndoe@example.com. It's important to stay secure online, so don't forget to do it!\n",
"```"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Use PromptGuard in LangChain expression\n",
"\n",
"There are functions that can be used with LangChain expression as well if a drop-in replacement doesn't offer the flexibility you need. "
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import langchain.utilities.promptguard as pgf\n",
"from langchain.schema.runnable import RunnableMap\n",
"from langchain.schema.output_parser import StrOutputParser\n",
"\n",
"\n",
"prompt=PromptTemplate.from_template(prompt_template), \n",
"llm = OpenAI()\n",
"pg_chain = (\n",
" pgf.sanitize\n",
" | RunnableMap(\n",
" {\n",
" \"response\": (lambda x: x[\"sanitized_input\"])\n",
" | prompt\n",
" | llm\n",
" | StrOutputParser(),\n",
" \"secure_context\": lambda x: x[\"secure_context\"],\n",
" }\n",
" )\n",
" | (lambda x: pgf.desanitize(x[\"response\"], x[\"secure_context\"]))\n",
")\n",
"\n",
"pg_chain.invoke({\"question\": \"Write a text message to remind John to do password reset for his website through his email to stay secure.\", \"history\": \"\"})"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "langchain",
"language": "python",
"name": "python3"
},
"language_info": {
"name": "python",
"version": "3.10.10"
},
"orig_nbformat": 4
},
"nbformat": 4,
"nbformat_minor": 2
}

View File

@@ -0,0 +1,111 @@
import logging
from typing import Any, Dict, List, Optional
from pydantic import Extra, root_validator
from langchain.callbacks.manager import CallbackManagerForLLMRun
from langchain.llms.base import LLM
from langchain.prompts.base import StringPromptValue
from langchain.utils import get_from_dict_or_env
logger = logging.getLogger(__name__)
class PromptGuardLLMWrapper(LLM):
"""
An LLM wrapper that uses PromptGuard to sanitize the prompt before
passing it to the LLM, and that desanitizes the response after
getting it from the LLM.
To use, you should have the `promptguard` python package installed,
and the environment variable `PROMPTGUARD_API_KEY` set with
your API key, or pass it as a named parameter to the constructor.
Example:
.. code-block:: python
prompt_guard_llm = PromptGuardLLM(llm=ChatOpenAI())
"""
llm: Any
"""The LLM to use."""
class Config:
"""Configuration for this pydantic object."""
extra = Extra.forbid
@root_validator()
def validate_environment(cls, values: Dict) -> Dict:
"""Validates that the PromptGuard API key and the Python package exist."""
api_key = get_from_dict_or_env(
values, "promptguard_api_key", "PROMPTGUARD_API_KEY"
)
if api_key is None:
raise ValueError(
"Could not find PROMPTGUARD_API_KEY in the environment. "
"Please set it to your PromptGuard API key."
"You can get it by creating an account on the PromptGuard website: "
"https://promptguard.opaque.co/ ."
)
try:
import promptguard as pg
assert pg.__package__ is not None
except ImportError:
raise ImportError(
"Could not import the `promptguard` Python package, "
"please install it with `pip install promptguard`."
)
return values
def _call(
self,
prompt: str,
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> str:
"""Use PromptGuard to do sanitization and desanitization
before and after running LLM.
This is an override of the base class method.
Parameters
----------
prompt: The prompt to pass into the model.
Returns
-------
The string generated by the model.
Example
-------
.. code-block:: python
response = prompt_guard_llm("Tell me a joke.")
"""
import promptguard as pg
# sanitize the prompt by replacing the sensitive information with a placeholder
sanitize_response: pg.SanitizeResponse = pg.sanitize(prompt)
sanitized_prompt_value_str = sanitize_response.sanitized_text
# call the LLM with the sanitized prompt and get the response
llm_response = self.llm.generate_prompt(
[StringPromptValue(text=sanitized_prompt_value_str)],
)
# desanitize the response by restoring the original sensitive information
desanitize_response: pg.DesanitizeResponse = pg.desanitize(
llm_response.generations[0][0].text,
secure_context=sanitize_response.secure_context,
)
return desanitize_response.desanitized_text
@property
def _llm_type(self) -> str:
"""Return type of LLM.
This is an override of the base class method.
"""
return "promptguard"

View File

@@ -0,0 +1,108 @@
import json
from typing import Dict, Union
def sanitize(
input: Union[str, Dict[str, str]]
) -> Dict[str, Union[str, Dict[str, str]]]:
"""
Sanitize input string or dict of strings by replacing sensitive data with
placeholders.
It returns the sanitized input string or dict of strings and the secure
context as a dict following the format:
{
"sanitized_input": <sanitized input string or dict of strings>,
"secure_context": <secure context>
}
The secure context is a bytes object that is needed to desanitize the response
from the LLM.
Parameters
----------
input : Union[str, Dict[str, str]]
input string or dict of strings
Returns
-------
Dict[str, Union[str, Dict[str, str]]]
sanitized input string or dict of strings and the secure context
as a dict following the format:
{
"sanitized_input": <sanitized input string or dict of strings>,
"secure_context": <secure context>
}
The `secure_context` needs to be passed to the `desanitize` function.
"""
try:
import promptguard as pg
except ImportError:
raise ImportError(
"Could not import the `promptguard` Python package, "
"please install it with `pip install promptguard`."
)
if isinstance(input, str):
# the input could be a string, so we sanitize the string
sanitize_response: pg.SanitizeResponse = pg.sanitize(input)
return {
"sanitized_input": sanitize_response.sanitized_text,
"secure_context": sanitize_response.secure_context,
}
if isinstance(input, dict):
# the input could be a dict[string, string], so we sanitize the values
values = list()
# get the values from the dict
for key in input:
values.append(input[key])
input_value_str = json.dumps(values)
# sanitize the values
sanitize_values_response: pg.SanitizeResponse = pg.sanitize(input_value_str)
# reconstruct the dict with the sanitized values
sanitized_input_values = json.loads(sanitize_values_response.sanitized_text)
idx = 0
sanitized_input = dict()
for key in input:
sanitized_input[key] = sanitized_input_values[idx]
idx += 1
return {
"sanitized_input": sanitized_input,
"secure_context": sanitize_values_response.secure_context,
}
raise ValueError(f"Unexpected input type {type(input)}")
def desanitize(sanitized_text: str, secure_context: bytes) -> str:
"""
Restore the original sensitive data from the sanitized text.
Parameters
----------
sanitized_text : str
sanitized text
secure_context : bytes
secure context returned by the `sanitize` function
Returns
-------
str
desanitized text
"""
try:
import promptguard as pg
except ImportError:
raise ImportError(
"Could not import the `promptguard` Python package, "
"please install it with `pip install promptguard`."
)
desanitize_response: pg.DesanitizeResponse = pg.desanitize(
sanitized_text, secure_context
)
return desanitize_response.desanitized_text

View File

@@ -0,0 +1,84 @@
import langchain.utilities.promptguard as pgf
from langchain import LLMChain, PromptTemplate
from langchain.llms import OpenAI
from langchain.llms.promptguard import PromptGuardLLMWrapper
from langchain.memory import ConversationBufferWindowMemory
from langchain.schema.output_parser import StrOutputParser
from langchain.schema.runnable import RunnableMap
prompt_template = """
As an AI assistant, you will answer questions according to given context.
Sensitive personal information in the question is masked for privacy.
For instance, if the original text says "Giana is good," it will be changed
to "PERSON_998 is good."
Here's how to handle these changes:
* Consider these masked phrases just as placeholders, but still refer to
them in a relevant way when answering.
* It's possible that different masked terms might mean the same thing.
Stick with the given term and don't modify it.
* All masked terms follow the "TYPE_ID" pattern.
* Please don't invent new masked terms. For instance, if you see "PERSON_998,"
don't come up with "PERSON_997" or "PERSON_999" unless they're already in the question.
Conversation History: ```{history}```
Context : ```During our recent meeting on February 23, 2023, at 10:30 AM,
John Doe provided me with his personal details. His email is johndoe@example.com
and his contact number is 650-456-7890. He lives in New York City, USA, and
belongs to the American nationality with Christian beliefs and a leaning towards
the Democratic party. He mentioned that he recently made a transaction using his
credit card 4111 1111 1111 1111 and transferred bitcoins to the wallet address
1A1zP1eP5QGefi2DMPTfTL5SLmv7DivfNa. While discussing his European travels, he
noted down his IBAN as GB29 NWBK 6016 1331 9268 19. Additionally, he provided
his website as https://johndoeportfolio.com. John also discussed
some of his US-specific details. He said his bank account number is
1234567890123456 and his drivers license is Y12345678. His ITIN is 987-65-4321,
and he recently renewed his passport,
the number for which is 123456789. He emphasized not to share his SSN, which is
669-45-6789. Furthermore, he mentioned that he accesses his work files remotely
through the IP 192.168.1.1 and has a medical license number MED-123456. ```
Question: ```{question}```
"""
def test_promptguard_llm_wrapper() -> None:
chain = LLMChain(
prompt=PromptTemplate.from_template(prompt_template),
llm=PromptGuardLLMWrapper(llm=OpenAI()),
memory=ConversationBufferWindowMemory(k=2),
)
output = chain.run(
{
"question": "Write a text message to remind John to do password reset \
for his website through his email to stay secure."
}
)
assert isinstance(output, str)
def test_promptguard_functions() -> None:
prompt = (PromptTemplate.from_template(prompt_template),)
llm = OpenAI()
pg_chain = (
pgf.sanitize
| RunnableMap(
{
"response": (lambda x: x["sanitized_input"]) # type: ignore
| prompt
| llm
| StrOutputParser(),
"secure_context": lambda x: x["secure_context"],
}
)
| (lambda x: pgf.desanitize(x["response"], x["secure_context"]))
)
pg_chain.invoke(
{
"question": "Write a text message to remind John to do password reset\
for his website through his email to stay secure.",
"history": "",
}
)