From 94becbf19260f77a83a1269635e585ef232da749 Mon Sep 17 00:00:00 2001
From: Bagatur
Date: Fri, 18 Aug 2023 18:31:39 -0700
Subject: [PATCH] cr

---
 libs/langchain/langchain/callbacks/manager.py |  4 +-
 libs/langchain/langchain/llms/__init__.py     |  3 +
 libs/langchain/langchain/llms/promptguard.py  | 74 ++++++++++---------
 .../llms/test_promptguard.py                  |  6 +-
 4 files changed, 47 insertions(+), 40 deletions(-)

diff --git a/libs/langchain/langchain/callbacks/manager.py b/libs/langchain/langchain/callbacks/manager.py
index 7016a13ed95..482a59b1d0e 100644
--- a/libs/langchain/langchain/callbacks/manager.py
+++ b/libs/langchain/langchain/callbacks/manager.py
@@ -582,7 +582,7 @@ class AsyncParentRunManager(AsyncRunManager):
         return manager
 
 
-class CallbackManagerForLLMRun(RunManager, LLMManagerMixin):
+class CallbackManagerForLLMRun(ParentRunManager, LLMManagerMixin):
     """Callback manager for LLM run."""
 
     def on_llm_new_token(
@@ -645,7 +645,7 @@ class CallbackManagerForLLMRun(RunManager, LLMManagerMixin):
         )
 
 
-class AsyncCallbackManagerForLLMRun(AsyncRunManager, LLMManagerMixin):
+class AsyncCallbackManagerForLLMRun(AsyncParentRunManager, LLMManagerMixin):
     """Async callback manager for LLM run."""
 
     async def on_llm_new_token(
diff --git a/libs/langchain/langchain/llms/__init__.py b/libs/langchain/langchain/llms/__init__.py
index b2728535a63..d46ce1ab0c7 100644
--- a/libs/langchain/langchain/llms/__init__.py
+++ b/libs/langchain/langchain/llms/__init__.py
@@ -69,6 +69,7 @@ from langchain.llms.petals import Petals
 from langchain.llms.pipelineai import PipelineAI
 from langchain.llms.predibase import Predibase
 from langchain.llms.predictionguard import PredictionGuard
+from langchain.llms.promptguard import PromptGuard
 from langchain.llms.promptlayer_openai import PromptLayerOpenAI, PromptLayerOpenAIChat
 from langchain.llms.replicate import Replicate
 from langchain.llms.rwkv import RWKV
@@ -141,6 +142,7 @@ __all__ = [
     "PredictionGuard",
     "PromptLayerOpenAI",
     "PromptLayerOpenAIChat",
+    "PromptGuard",
     "RWKV",
     "Replicate",
     "SagemakerEndpoint",
@@ -205,6 +207,7 @@ type_to_cls_dict: Dict[str, Type[BaseLLM]] = {
     "petals": Petals,
     "pipelineai": PipelineAI,
     "predibase": Predibase,
+    "promptguard": PromptGuard,
     "replicate": Replicate,
     "rwkv": RWKV,
     "sagemaker_endpoint": SagemakerEndpoint,
diff --git a/libs/langchain/langchain/llms/promptguard.py b/libs/langchain/langchain/llms/promptguard.py
index 200d9b5e3d1..82ae7ad15bd 100644
--- a/libs/langchain/langchain/llms/promptguard.py
+++ b/libs/langchain/langchain/llms/promptguard.py
@@ -3,31 +3,34 @@ from typing import Any, Dict, List, Optional
 
 from langchain.callbacks.manager import CallbackManagerForLLMRun
 from langchain.llms.base import LLM
-from langchain.prompts.base import StringPromptValue
 from langchain.pydantic_v1 import Extra, root_validator
+from langchain.schema.language_model import BaseLanguageModel
 from langchain.utils import get_from_dict_or_env
 
 logger = logging.getLogger(__name__)
 
 
-class PromptGuardLLMWrapper(LLM):
-    """
-    An LLM wrapper that uses PromptGuard to sanitize the prompt before
-    passing it to the LLM, and that desanitizes the response after
-    getting it from the LLM.
+class PromptGuard(LLM):
+    """An LLM wrapper that uses PromptGuard to sanitize prompts.
 
-    To use, you should have the `promptguard` python package installed,
-    and the environment variable `PROMPTGUARD_API_KEY` set with
+    Wraps another LLM and sanitizes prompts before passing them to the LLM, then
+    de-sanitizes the response.
+
+    To use, you should have the ``promptguard`` python package installed,
+    and the environment variable ``PROMPTGUARD_API_KEY`` set with
     your API key, or pass it as a named parameter to the constructor.
 
     Example:
         .. code-block:: python
 
-            prompt_guard_llm = PromptGuardLLM(llm=ChatOpenAI())
+            from langchain.llms import PromptGuard
+            from langchain.chat_models import ChatOpenAI
+
+            prompt_guard_llm = PromptGuard(base_llm=ChatOpenAI())
     """
 
-    llm: Any
-    """The LLM to use."""
+    base_llm: BaseLanguageModel
+    """The base LLM to use."""
 
     class Config:
         """Configuration for this pydantic object."""
@@ -37,25 +40,29 @@ class PromptGuardLLMWrapper(LLM):
     @root_validator()
     def validate_environment(cls, values: Dict) -> Dict:
         """Validates that the PromptGuard API key and the Python package exist."""
+        try:
+            import promptguard as pg
+        except ImportError:
+            raise ImportError(
+                "Could not import the `promptguard` Python package, "
+                "please install it with `pip install promptguard`."
+            )
+        if pg.__package__ is None:
+            raise ValueError(
+                "Could not properly import `promptguard`, "
+                "promptguard.__package__ is None."
+            )
+
         api_key = get_from_dict_or_env(
-            values, "promptguard_api_key", "PROMPTGUARD_API_KEY"
+            values, "promptguard_api_key", "PROMPTGUARD_API_KEY", default=""
         )
-        if api_key is None:
+        if not api_key:
             raise ValueError(
                 "Could not find PROMPTGUARD_API_KEY in the environment. "
                 "Please set it to your PromptGuard API key."
                 "You can get it by creating an account on the PromptGuard website: "
                 "https://promptguard.opaque.co/ ."
             )
-        try:
-            import promptguard as pg
-
-            assert pg.__package__ is not None
-        except ImportError:
-            raise ImportError(
-                "Could not import the `promptguard` Python package, "
-                "please install it with `pip install promptguard`."
-            )
         return values
 
     def _call(
@@ -65,38 +72,35 @@ class PromptGuardLLMWrapper(LLM):
         run_manager: Optional[CallbackManagerForLLMRun] = None,
         **kwargs: Any,
     ) -> str:
-        """Use PromptGuard to do sanitization and desanitization
-        before and after running LLM.
+        """Call base LLM with sanitization before and de-sanitization after.
 
-        This is an override of the base class method.
-
-        Parameters
-        ----------
+        Args:
             prompt: The prompt to pass into the model.
 
-        Returns
-        -------
+        Returns:
             The string generated by the model.
 
-        Example
-        -------
+        Example:
             .. code-block:: python
 
+                response = prompt_guard_llm("Tell me a joke.")
         """
         import promptguard as pg
 
+        _run_manager = run_manager or CallbackManagerForLLMRun.get_noop_manager()
+
         # sanitize the prompt by replacing the sensitive information with a placeholder
         sanitize_response: pg.SanitizeResponse = pg.sanitize(prompt)
         sanitized_prompt_value_str = sanitize_response.sanitized_text
 
         # call the LLM with the sanitized prompt and get the response
-        llm_response = self.llm.generate_prompt(
-            [StringPromptValue(text=sanitized_prompt_value_str)],
+        llm_response = self.base_llm.predict(
+            sanitized_prompt_value_str,
             stop=stop,
             callbacks=_run_manager.get_child()
         )
 
         # desanitize the response by restoring the original sensitive information
         desanitize_response: pg.DesanitizeResponse = pg.desanitize(
-            llm_response.generations[0][0].text,
+            llm_response,
             secure_context=sanitize_response.secure_context,
         )
         return desanitize_response.desanitized_text
diff --git a/libs/langchain/tests/integration_tests/llms/test_promptguard.py b/libs/langchain/tests/integration_tests/llms/test_promptguard.py
index b37e1714f75..599df595a0a 100644
--- a/libs/langchain/tests/integration_tests/llms/test_promptguard.py
+++ b/libs/langchain/tests/integration_tests/llms/test_promptguard.py
@@ -1,7 +1,7 @@
 import langchain.utilities.promptguard as pgf
 from langchain import LLMChain, PromptTemplate
 from langchain.llms import OpenAI
-from langchain.llms.promptguard import PromptGuardLLMWrapper
+from langchain.llms.promptguard import PromptGuard
 from langchain.memory import ConversationBufferWindowMemory
 from langchain.schema.output_parser import StrOutputParser
 from langchain.schema.runnable import RunnableMap
@@ -42,10 +42,10 @@ Question: ```{question}```
 """
 
 
-def test_promptguard_llm_wrapper() -> None:
+def test_promptguard() -> None:
     chain = LLMChain(
         prompt=PromptTemplate.from_template(prompt_template),
-        llm=PromptGuardLLMWrapper(llm=OpenAI()),
+        llm=PromptGuard(base_llm=OpenAI()),
         memory=ConversationBufferWindowMemory(k=2),
     )
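
Note (not part of the patch): a minimal usage sketch for the new wrapper, assuming
the `promptguard` package is installed, PROMPTGUARD_API_KEY and OPENAI_API_KEY are
set, and using an illustrative prompt string:

    from langchain.chat_models import ChatOpenAI
    from langchain.llms import PromptGuard

    # Wrap a chat model: PromptGuard sanitizes the prompt before the wrapped
    # model sees it and de-sanitizes the model's response afterwards.
    prompt_guard_llm = PromptGuard(base_llm=ChatOpenAI())

    # Sensitive details in the prompt are replaced with placeholders by
    # promptguard.sanitize() before the call and restored on the way back.
    response = prompt_guard_llm(
        "Write a reply to Jane Doe (jane.doe@example.com) about her invoice."
    )
    print(response)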