This commit is contained in:
Bagatur
2023-08-18 18:31:39 -07:00
parent b8aa62d361
commit 94becbf192
4 changed files with 47 additions and 40 deletions

View File

@@ -582,7 +582,7 @@ class AsyncParentRunManager(AsyncRunManager):
return manager
class CallbackManagerForLLMRun(RunManager, LLMManagerMixin):
class CallbackManagerForLLMRun(ParentRunManager, LLMManagerMixin):
"""Callback manager for LLM run."""
def on_llm_new_token(
@@ -645,7 +645,7 @@ class CallbackManagerForLLMRun(RunManager, LLMManagerMixin):
)
class AsyncCallbackManagerForLLMRun(AsyncRunManager, LLMManagerMixin):
class AsyncCallbackManagerForLLMRun(AsyncParentRunManager, LLMManagerMixin):
"""Async callback manager for LLM run."""
async def on_llm_new_token(

View File

@@ -69,6 +69,7 @@ from langchain.llms.petals import Petals
from langchain.llms.pipelineai import PipelineAI
from langchain.llms.predibase import Predibase
from langchain.llms.predictionguard import PredictionGuard
from langchain.llms.promptguard import PromptGuard
from langchain.llms.promptlayer_openai import PromptLayerOpenAI, PromptLayerOpenAIChat
from langchain.llms.replicate import Replicate
from langchain.llms.rwkv import RWKV
@@ -141,6 +142,7 @@ __all__ = [
"PredictionGuard",
"PromptLayerOpenAI",
"PromptLayerOpenAIChat",
"PromptGuard",
"RWKV",
"Replicate",
"SagemakerEndpoint",
@@ -205,6 +207,7 @@ type_to_cls_dict: Dict[str, Type[BaseLLM]] = {
"petals": Petals,
"pipelineai": PipelineAI,
"predibase": Predibase,
"promptguard": PromptGuard,
"replicate": Replicate,
"rwkv": RWKV,
"sagemaker_endpoint": SagemakerEndpoint,

View File

@@ -3,31 +3,34 @@ from typing import Any, Dict, List, Optional
from langchain.callbacks.manager import CallbackManagerForLLMRun
from langchain.llms.base import LLM
from langchain.prompts.base import StringPromptValue
from langchain.pydantic_v1 import Extra, root_validator
from langchain.schema.language_model import BaseLanguageModel
from langchain.utils import get_from_dict_or_env
logger = logging.getLogger(__name__)
class PromptGuardLLMWrapper(LLM):
"""
An LLM wrapper that uses PromptGuard to sanitize the prompt before
passing it to the LLM, and that desanitizes the response after
getting it from the LLM.
class PromptGuard(LLM):
"""An LLM wrapper that uses PromptGuard to sanitize prompts.
To use, you should have the `promptguard` python package installed,
and the environment variable `PROMPTGUARD_API_KEY` set with
Wraps another LLM and sanitizes prompts before passing it to the LLM, then
de-sanitizes the response.
To use, you should have the ``promptguard`` python package installed,
and the environment variable ``PROMPTGUARD_API_KEY`` set with
your API key, or pass it as a named parameter to the constructor.
Example:
.. code-block:: python
prompt_guard_llm = PromptGuardLLM(llm=ChatOpenAI())
from langchain.llms import PromptGuardLLM
from langchain.chat_models import ChatOpenAI
prompt_guard_llm = PromptGuardLLM(base_llm=ChatOpenAI())
"""
llm: Any
"""The LLM to use."""
base_llm: BaseLanguageModel
"""The base LLM to use."""
class Config:
"""Configuration for this pydantic object."""
@@ -37,25 +40,29 @@ class PromptGuardLLMWrapper(LLM):
@root_validator()
def validate_environment(cls, values: Dict) -> Dict:
"""Validates that the PromptGuard API key and the Python package exist."""
try:
import promptguard as pg
except ImportError:
raise ImportError(
"Could not import the `promptguard` Python package, "
"please install it with `pip install promptguard`."
)
if pg.__package__ is None:
raise ValueError(
"Could not properly import `promptguard`, "
"promptguard.__package__ is None."
)
api_key = get_from_dict_or_env(
values, "promptguard_api_key", "PROMPTGUARD_API_KEY"
values, "promptguard_api_key", "PROMPTGUARD_API_KEY", default=""
)
if api_key is None:
if not api_key:
raise ValueError(
"Could not find PROMPTGUARD_API_KEY in the environment. "
"Please set it to your PromptGuard API key."
"You can get it by creating an account on the PromptGuard website: "
"https://promptguard.opaque.co/ ."
)
try:
import promptguard as pg
assert pg.__package__ is not None
except ImportError:
raise ImportError(
"Could not import the `promptguard` Python package, "
"please install it with `pip install promptguard`."
)
return values
def _call(
@@ -65,38 +72,35 @@ class PromptGuardLLMWrapper(LLM):
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> str:
"""Use PromptGuard to do sanitization and desanitization
before and after running LLM.
"""Call base LLM with sanitization before and de-sanitization after.
This is an override of the base class method.
Parameters
----------
Args:
prompt: The prompt to pass into the model.
Returns
-------
Returns:
The string generated by the model.
Example
-------
Example:
.. code-block:: python
response = prompt_guard_llm("Tell me a joke.")
"""
import promptguard as pg
_run_manager = run_manager or CallbackManagerForLLMRun.get_noop_manager()
# sanitize the prompt by replacing the sensitive information with a placeholder
sanitize_response: pg.SanitizeResponse = pg.sanitize(prompt)
sanitized_prompt_value_str = sanitize_response.sanitized_text
# call the LLM with the sanitized prompt and get the response
llm_response = self.llm.generate_prompt(
[StringPromptValue(text=sanitized_prompt_value_str)],
llm_response = self.base_llm.predict(
sanitized_prompt_value_str, stop=stop, callbacks=_run_manager.get_child()
)
# desanitize the response by restoring the original sensitive information
desanitize_response: pg.DesanitizeResponse = pg.desanitize(
llm_response.generations[0][0].text,
llm_response,
secure_context=sanitize_response.secure_context,
)
return desanitize_response.desanitized_text

View File

@@ -1,7 +1,7 @@
import langchain.utilities.promptguard as pgf
from langchain import LLMChain, PromptTemplate
from langchain.llms import OpenAI
from langchain.llms.promptguard import PromptGuardLLMWrapper
from langchain.llms.promptguard import PromptGuard
from langchain.memory import ConversationBufferWindowMemory
from langchain.schema.output_parser import StrOutputParser
from langchain.schema.runnable import RunnableMap
@@ -42,10 +42,10 @@ Question: ```{question}```
"""
def test_promptguard_llm_wrapper() -> None:
def test_promptguard() -> None:
chain = LLMChain(
prompt=PromptTemplate.from_template(prompt_template),
llm=PromptGuardLLMWrapper(llm=OpenAI()),
llm=PromptGuard(llm=OpenAI()),
memory=ConversationBufferWindowMemory(k=2),
)