feat: Add PromptGuard integration (#9481)

Add PromptGuard integration
-------
There are two approaches to integrating PromptGuard with a LangChain
application (a sketch of both follows this list):

1. The `PromptGuard` LLM wrapper
2. Functions that can be used in a LangChain Expression Language (LCEL) chain
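A minimal sketch of both approaches, adapted from the integration test added further down in this PR. The prompt template, input text, and the use of `OpenAI` are illustrative assumptions, and the snippet assumes `OPENAI_API_KEY` and `PROMPTGUARD_API_KEY` are set in the environment:

```python
import langchain.utilities.promptguard as pgf
from langchain import LLMChain, PromptTemplate
from langchain.llms import OpenAI
from langchain.llms.promptguard import PromptGuard
from langchain.schema.output_parser import StrOutputParser
from langchain.schema.runnable import RunnableMap

# A toy prompt; any template with a {question} variable works here.
prompt = PromptTemplate.from_template("Answer briefly: {question}")

# Approach 1: the PromptGuard LLM wrapper. The prompt is sanitized before the
# base LLM is called, and the response is de-sanitized afterwards.
chain = LLMChain(prompt=prompt, llm=PromptGuard(base_llm=OpenAI()))
print(chain.run({"question": "Draft an email to johndoe@example.com about a password reset."}))

# Approach 2: the sanitize/desanitize functions composed in an LCEL chain.
llm = OpenAI()
pg_chain = (
    pgf.sanitize
    | RunnableMap(
        {
            "response": (lambda x: x["sanitized_input"])
            | prompt
            | llm
            | StrOutputParser(),
            "secure_context": lambda x: x["secure_context"],
        }
    )
    | (lambda x: pgf.desanitize(x["response"], x["secure_context"]))
)
print(pg_chain.invoke({"question": "Draft an email to johndoe@example.com about a password reset."}))
```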

-----
- Dependencies: the `promptguard` Python package, which is a runtime
requirement if you want to try out the demo.

- Thanks to @baskaryan and @hwchase17 for the ideas and suggestions throughout
the development process.

---------

Co-authored-by: Bagatur <baskaryan@gmail.com>
Author: Zizhong Zhang
Date: 2023-08-21 14:59:36 -07:00
Committed by: GitHub
Parent: 6c308aabae
Commit: 00eff8c4a7
5 changed files with 516 additions and 0 deletions


@@ -69,6 +69,7 @@ from langchain.llms.petals import Petals
from langchain.llms.pipelineai import PipelineAI
from langchain.llms.predibase import Predibase
from langchain.llms.predictionguard import PredictionGuard
from langchain.llms.promptguard import PromptGuard
from langchain.llms.promptlayer_openai import PromptLayerOpenAI, PromptLayerOpenAIChat
from langchain.llms.replicate import Replicate
from langchain.llms.rwkv import RWKV
@@ -141,6 +142,7 @@ __all__ = [
"PredictionGuard",
"PromptLayerOpenAI",
"PromptLayerOpenAIChat",
"PromptGuard",
"RWKV",
"Replicate",
"SagemakerEndpoint",
@@ -205,6 +207,7 @@ type_to_cls_dict: Dict[str, Type[BaseLLM]] = {
"petals": Petals,
"pipelineai": PipelineAI,
"predibase": Predibase,
"promptguard": PromptGuard,
"replicate": Replicate,
"rwkv": RWKV,
"sagemaker_endpoint": SagemakerEndpoint,


@@ -0,0 +1,116 @@
import logging
from typing import Any, Dict, List, Optional
from langchain.callbacks.manager import CallbackManagerForLLMRun
from langchain.llms.base import LLM
from langchain.pydantic_v1 import Extra, root_validator
from langchain.schema.language_model import BaseLanguageModel
from langchain.utils import get_from_dict_or_env
logger = logging.getLogger(__name__)
class PromptGuard(LLM):
    """An LLM wrapper that uses PromptGuard to sanitize prompts.

    Wraps another LLM and sanitizes prompts before passing them to that LLM,
    then de-sanitizes the response.

    To use, you should have the ``promptguard`` python package installed,
    and the environment variable ``PROMPTGUARD_API_KEY`` set with
    your API key, or pass it as a named parameter to the constructor.

    Example:
        .. code-block:: python

            from langchain.llms import PromptGuard
            from langchain.chat_models import ChatOpenAI

            prompt_guard_llm = PromptGuard(base_llm=ChatOpenAI())
    """

    base_llm: BaseLanguageModel
    """The base LLM to use."""

    class Config:
        """Configuration for this pydantic object."""

        extra = Extra.forbid

    @root_validator()
    def validate_environment(cls, values: Dict) -> Dict:
        """Validate that the PromptGuard API key and the Python package exist."""
        try:
            import promptguard as pg
        except ImportError:
            raise ImportError(
                "Could not import the `promptguard` Python package, "
                "please install it with `pip install promptguard`."
            )
        if pg.__package__ is None:
            raise ValueError(
                "Could not properly import `promptguard`, "
                "promptguard.__package__ is None."
            )

        api_key = get_from_dict_or_env(
            values, "promptguard_api_key", "PROMPTGUARD_API_KEY", default=""
        )
        if not api_key:
            raise ValueError(
                "Could not find PROMPTGUARD_API_KEY in the environment. "
                "Please set it to your PromptGuard API key. "
                "You can get it by creating an account on the PromptGuard website: "
                "https://promptguard.opaque.co/."
            )
        return values

    def _call(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> str:
        """Call the base LLM with sanitization before and de-sanitization after.

        Args:
            prompt: The prompt to pass into the model.

        Returns:
            The string generated by the model.

        Example:
            .. code-block:: python

                response = prompt_guard_llm("Tell me a joke.")
        """
        import promptguard as pg

        _run_manager = run_manager or CallbackManagerForLLMRun.get_noop_manager()

        # Sanitize the prompt by replacing sensitive information with placeholders.
        sanitize_response: pg.SanitizeResponse = pg.sanitize(prompt)
        sanitized_prompt_value_str = sanitize_response.sanitized_text

        # TODO: Add in callbacks once child runs for LLMs are supported by LangSmith.
        # Call the base LLM with the sanitized prompt and get the response.
        llm_response = self.base_llm.predict(
            sanitized_prompt_value_str,
            stop=stop,
        )

        # De-sanitize the response by restoring the original sensitive information.
        desanitize_response: pg.DesanitizeResponse = pg.desanitize(
            llm_response,
            secure_context=sanitize_response.secure_context,
        )
        return desanitize_response.desanitized_text

    @property
    def _llm_type(self) -> str:
        """Return the type of LLM.

        This is an override of the base class method.
        """
        return "promptguard"


@@ -0,0 +1,99 @@
import json
from typing import Dict, Union


def sanitize(
    input: Union[str, Dict[str, str]]
) -> Dict[str, Union[str, Dict[str, str]]]:
    """
    Sanitize input string or dict of strings by replacing sensitive data with
    placeholders.

    It returns the sanitized input string or dict of strings and the secure
    context as a dict following the format:
    {
        "sanitized_input": <sanitized input string or dict of strings>,
        "secure_context": <secure context>
    }

    The secure context is a bytes object that is needed to de-sanitize the
    response from the LLM.

    Args:
        input: Input string or dict of strings.

    Returns:
        Sanitized input string or dict of strings and the secure context
        as a dict following the format:
        {
            "sanitized_input": <sanitized input string or dict of strings>,
            "secure_context": <secure context>
        }

        The `secure_context` needs to be passed to the `desanitize` function.
    """
    try:
        import promptguard as pg
    except ImportError:
        raise ImportError(
            "Could not import the `promptguard` Python package, "
            "please install it with `pip install promptguard`."
        )

    if isinstance(input, str):
        # The input is a single string, so sanitize it directly.
        sanitize_response: pg.SanitizeResponse = pg.sanitize(input)
        return {
            "sanitized_input": sanitize_response.sanitized_text,
            "secure_context": sanitize_response.secure_context,
        }

    if isinstance(input, dict):
        # The input is a dict[str, str], so sanitize its values.
        # Serialize the values so they can be sanitized in a single call.
        values = list(input.values())
        input_value_str = json.dumps(values)

        # Sanitize the serialized values.
        sanitize_values_response: pg.SanitizeResponse = pg.sanitize(input_value_str)

        # Reconstruct the dict with the sanitized values.
        sanitized_input_values = json.loads(sanitize_values_response.sanitized_text)
        sanitized_input = dict(zip(input.keys(), sanitized_input_values))
        return {
            "sanitized_input": sanitized_input,
            "secure_context": sanitize_values_response.secure_context,
        }

    raise ValueError(f"Unexpected input type {type(input)}")


def desanitize(sanitized_text: str, secure_context: bytes) -> str:
    """
    Restore the original sensitive data from the sanitized text.

    Args:
        sanitized_text: Sanitized text.
        secure_context: Secure context returned by the `sanitize` function.

    Returns:
        De-sanitized text.
    """
    try:
        import promptguard as pg
    except ImportError:
        raise ImportError(
            "Could not import the `promptguard` Python package, "
            "please install it with `pip install promptguard`."
        )
    desanitize_response: pg.DesanitizeResponse = pg.desanitize(
        sanitized_text, secure_context
    )
    return desanitize_response.desanitized_text
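
A minimal sketch of the round-trip these helpers provide; the input text, the placeholder format shown in the comment, and the faked `llm_answer` are illustrative assumptions rather than real PromptGuard output:

```python
import langchain.utilities.promptguard as pgf

# Sanitize a string containing PII. Sensitive values are replaced with
# placeholders and a secure_context is returned for later restoration.
out = pgf.sanitize("Contact John Doe at johndoe@example.com or 650-456-7890.")
sanitized = out["sanitized_input"]    # e.g. "Contact PERSON_998 at ..." (placeholder format illustrative)
secure_context = out["secure_context"]

# The sanitized text is what would be sent to the LLM. Here the LLM response
# is faked so the example stays self-contained.
llm_answer = "Noted, I will follow up: " + sanitized

# De-sanitize the (hypothetical) LLM output, restoring the original values.
print(pgf.desanitize(llm_answer, secure_context))
```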


@@ -0,0 +1,84 @@
import langchain.utilities.promptguard as pgf
from langchain import LLMChain, PromptTemplate
from langchain.llms import OpenAI
from langchain.llms.promptguard import PromptGuard
from langchain.memory import ConversationBufferWindowMemory
from langchain.schema.output_parser import StrOutputParser
from langchain.schema.runnable import RunnableMap
prompt_template = """
As an AI assistant, you will answer questions according to given context.
Sensitive personal information in the question is masked for privacy.
For instance, if the original text says "Giana is good," it will be changed
to "PERSON_998 is good."
Here's how to handle these changes:
* Consider these masked phrases just as placeholders, but still refer to
them in a relevant way when answering.
* It's possible that different masked terms might mean the same thing.
Stick with the given term and don't modify it.
* All masked terms follow the "TYPE_ID" pattern.
* Please don't invent new masked terms. For instance, if you see "PERSON_998,"
don't come up with "PERSON_997" or "PERSON_999" unless they're already in the question.
Conversation History: ```{history}```
Context : ```During our recent meeting on February 23, 2023, at 10:30 AM,
John Doe provided me with his personal details. His email is johndoe@example.com
and his contact number is 650-456-7890. He lives in New York City, USA, and
belongs to the American nationality with Christian beliefs and a leaning towards
the Democratic party. He mentioned that he recently made a transaction using his
credit card 4111 1111 1111 1111 and transferred bitcoins to the wallet address
1A1zP1eP5QGefi2DMPTfTL5SLmv7DivfNa. While discussing his European travels, he
noted down his IBAN as GB29 NWBK 6016 1331 9268 19. Additionally, he provided
his website as https://johndoeportfolio.com. John also discussed
some of his US-specific details. He said his bank account number is
1234567890123456 and his drivers license is Y12345678. His ITIN is 987-65-4321,
and he recently renewed his passport,
the number for which is 123456789. He emphasized not to share his SSN, which is
669-45-6789. Furthermore, he mentioned that he accesses his work files remotely
through the IP 192.168.1.1 and has a medical license number MED-123456. ```
Question: ```{question}```
"""
def test_promptguard() -> None:
    chain = LLMChain(
        prompt=PromptTemplate.from_template(prompt_template),
        llm=PromptGuard(base_llm=OpenAI()),
        memory=ConversationBufferWindowMemory(k=2),
    )

    output = chain.run(
        {
            "question": "Write a text message to remind John to do password reset "
            "for his website through his email to stay secure."
        }
    )
    assert isinstance(output, str)


def test_promptguard_functions() -> None:
    prompt = PromptTemplate.from_template(prompt_template)
    llm = OpenAI()

    pg_chain = (
        pgf.sanitize
        | RunnableMap(
            {
                "response": (lambda x: x["sanitized_input"])  # type: ignore
                | prompt
                | llm
                | StrOutputParser(),
                "secure_context": lambda x: x["secure_context"],
            }
        )
        | (lambda x: pgf.desanitize(x["response"], x["secure_context"]))
    )

    pg_chain.invoke(
        {
            "question": "Write a text message to remind John to do password reset "
            "for his website through his email to stay secure.",
            "history": "",
        }
    )