feat: PromptGuard takes a list of str (#9948)

We recently decided that PromptGuard should take a list of strings
instead of a single string.
@ggroode implemented the integration change.

---------

Co-authored-by: ggroode <ggroode@berkeley.edu>
Co-authored-by: ggroode <46691276+ggroode@users.noreply.github.com>
Zizhong Zhang 2023-08-29 12:22:30 -07:00 committed by GitHub
parent ede45f535e
commit 8bd7a9d18e
2 changed files with 6 additions and 8 deletions
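
As a quick orientation before the diff: pg.sanitize now expects a list of strings and returns a SanitizeResponse whose sanitized_texts is a parallel list, so single-string callers take element 0. A minimal sketch of the new call pattern; the promptguard import alias and the sample prompt are assumptions for illustration, only the list-in / sanitized_texts-out shape comes from this PR.

# Minimal sketch of the new call pattern (illustrative; the import alias and
# the sample prompt are assumptions, the API shape mirrors the diff below).
import promptguard as pg

response: pg.SanitizeResponse = pg.sanitize(["Email john.doe@example.com about the invoice."])
sanitized_prompt = response.sanitized_texts[0]  # single-string callers take element 0
secure_context = response.secure_context        # also returned, as seen in the second file below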


@@ -90,8 +90,8 @@ class PromptGuard(LLM):
         _run_manager = run_manager or CallbackManagerForLLMRun.get_noop_manager()
         # sanitize the prompt by replacing the sensitive information with a placeholder
-        sanitize_response: pg.SanitizeResponse = pg.sanitize(prompt)
-        sanitized_prompt_value_str = sanitize_response.sanitized_text
+        sanitize_response: pg.SanitizeResponse = pg.sanitize([prompt])
+        sanitized_prompt_value_str = sanitize_response.sanitized_texts[0]
         # TODO: Add in callbacks once child runs for LLMs are supported by LangSmith.
         # call the LLM with the sanitized prompt and get the response

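Because pg.sanitize now accepts a list, several strings can be sanitized in one call and matched back up with their originals by position. A small illustrative sketch, not part of this PR; the prompts and the zip pairing are made up, only the list-in / sanitized_texts-out shape is from the change itself.

# Illustrative batching sketch (assumes `pg` is the promptguard client, as above):
# sanitize several prompts in one call and pair each original with its
# sanitized counterpart by position.
prompts = [
    "Summarize the contract signed by Jane Roe.",
    "Draft a reply to john.doe@example.com.",
]
batch_response: pg.SanitizeResponse = pg.sanitize(prompts)
for original, sanitized in zip(prompts, batch_response.sanitized_texts):
    print(sanitized)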

@@ -1,4 +1,3 @@
-import json
 from typing import Dict, Union
@@ -41,9 +40,9 @@ def sanitize(
     if isinstance(input, str):
         # the input could be a string, so we sanitize the string
-        sanitize_response: pg.SanitizeResponse = pg.sanitize(input)
+        sanitize_response: pg.SanitizeResponse = pg.sanitize([input])
         return {
-            "sanitized_input": sanitize_response.sanitized_text,
+            "sanitized_input": sanitize_response.sanitized_texts[0],
             "secure_context": sanitize_response.secure_context,
         }
@@ -54,13 +53,12 @@ def sanitize(
     # get the values from the dict
     for key in input:
         values.append(input[key])
-    input_value_str = json.dumps(values)
     # sanitize the values
-    sanitize_values_response: pg.SanitizeResponse = pg.sanitize(input_value_str)
+    sanitize_values_response: pg.SanitizeResponse = pg.sanitize(values)
     # reconstruct the dict with the sanitized values
-    sanitized_input_values = json.loads(sanitize_values_response.sanitized_text)
+    sanitized_input_values = sanitize_values_response.sanitized_texts
     idx = 0
     sanitized_input = dict()
     for key in input:
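
The hunk above stops just before the dict is rebuilt. A standalone sketch of how the whole dict branch reads after this change; only the pg.sanitize(values) / sanitized_texts lines mirror the hunk, while the zip-based rebuild, the return shape, and the promptguard import alias are assumptions for illustration.

# Standalone sketch, not the library code: the dict branch of sanitize()
# end to end after this change. The rebuild via zip() and the return shape
# are inferred; only the pg.sanitize(values) / sanitized_texts usage is
# taken from the hunk above.
from typing import Dict

import promptguard as pg  # assumed import alias


def sanitize_dict_sketch(input: Dict[str, str]) -> dict:
    # sanitize every value in one call; sanitized_texts is positionally aligned
    values = [input[key] for key in input]
    response: pg.SanitizeResponse = pg.sanitize(values)
    # reconstruct the dict, pairing each key with its sanitized value by position
    sanitized_input = {key: text for key, text in zip(input, response.sanitized_texts)}
    return {
        "sanitized_input": sanitized_input,
        "secure_context": response.secure_context,
    }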