Improve prompt injection detection (#14842)

- **Description:** This is an addition to [my previous
PR](https://github.com/langchain-ai/langchain/pull/13930) with
improvements to flexibility, allowing different models, and an updated
notebook that uses the ONNX runtime for faster speed. Since the last PR, [our
model](https://huggingface.co/laiyer/deberta-v3-base-prompt-injection)
has had more than 660k downloads, and the [public
benchmark](https://huggingface.co/spaces/laiyer/prompt-injection-benchmark)
shows far fewer false positives than the previous model from deepset.
Additionally, on the ONNX runtime, it can run 3x faster on the
CPU, which might be handy for builders using LangChain.
- **Issue:** N/A
- **Dependencies:** N/A
- **Tag maintainer:** N/A
- **Twitter handle:** `@laiyer_ai`
This commit is contained in:
Oleksandr Yaremchuk
2023-12-19 02:50:21 +01:00
committed by GitHub
parent f8dccaa027
commit d82a3828f2
2 changed files with 131 additions and 51 deletions

View File

@@ -1,7 +1,7 @@
"""Tool for the identification of prompt injection attacks."""
from __future__ import annotations
from typing import TYPE_CHECKING, Any
from typing import TYPE_CHECKING, Union
from langchain.pydantic_v1 import Field, root_validator
from langchain.tools.base import BaseTool
@@ -10,17 +10,39 @@ if TYPE_CHECKING:
from transformers import Pipeline
class PromptInjectionException(ValueError):
    """Exception raised when a prompt injection attack is detected.

    Subclasses ``ValueError`` so that callers catching ``ValueError`` keep
    working.

    Args:
        message: Human-readable description of the detection.
        score: Score attached to the detection. Defaults to ``1.0``.
    """

    def __init__(
        self, message: str = "Prompt injection attack detected", score: float = 1.0
    ) -> None:
        self.message = message
        self.score = score
        # Only the message is forwarded, so ``str(exc)`` and ``exc.args``
        # carry the message alone; the score is available as ``exc.score``.
        super().__init__(self.message)
def _model_default_factory(
model_name: str = "deepset/deberta-v3-base-injection"
model_name: str = "laiyer/deberta-v3-base-prompt-injection",
) -> Pipeline:
try:
from transformers import pipeline
from transformers import (
AutoModelForSequenceClassification,
AutoTokenizer,
pipeline,
)
except ImportError as e:
raise ImportError(
"Cannot import transformers, please install with "
"`pip install transformers`."
) from e
return pipeline("text-classification", model=model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSequenceClassification.from_pretrained(model_name)
return pipeline(
"text-classification",
model=model,
tokenizer=tokenizer,
max_length=512, # default length of BERT models
truncation=True, # otherwise it will fail on long prompts
)
class HuggingFaceInjectionIdentifier(BaseTool):
@@ -32,13 +54,26 @@ class HuggingFaceInjectionIdentifier(BaseTool):
"Useful for when you need to ensure that prompt is free of injection attacks. "
"Input should be any message from the user."
)
model: Any = Field(default_factory=_model_default_factory)
model: Union[Pipeline, str, None] = Field(default_factory=_model_default_factory)
"""Model to use for prompt injection detection.
Can be specified as transformers Pipeline or string. String should correspond to the
model name of a text-classification transformers model. Defaults to
``deepset/deberta-v3-base-injection`` model.
``laiyer/deberta-v3-base-prompt-injection`` model.
"""
threshold: float = Field(
description="Threshold for prompt injection detection.", default=0.5
)
"""Threshold for prompt injection detection.
Defaults to 0.5."""
injection_label: str = Field(
description="Label of the injection for prompt injection detection.",
default="INJECTION",
)
"""Label for prompt injection detection model.
Defaults to ``INJECTION``. Value depends on the model used."""
@root_validator(pre=True)
def validate_environment(cls, values: dict) -> dict:
@@ -49,7 +84,12 @@ class HuggingFaceInjectionIdentifier(BaseTool):
def _run(self, query: str) -> str:
    """Use the tool.

    Runs ``self.model`` on ``query``, takes the highest-scoring prediction
    and derives an injection score from it. The query is returned unchanged
    when that score does not exceed ``self.threshold``.

    Args:
        query: Text to check for prompt injection.

    Returns:
        The original ``query``.

    Raises:
        PromptInjectionException: If the injection score is greater than
            ``self.threshold``.
    """
    result = self.model(query)
    # Put the highest-confidence prediction first.
    result = sorted(result, key=lambda x: x["score"], reverse=True)
    top = result[0]
    # When the top label is not the injection label, its complement is used
    # as the injection score.
    # NOTE(review): this presumes a two-label classifier - confirm for
    # custom models.
    score = (
        top["score"]
        if top["label"] == self.injection_label
        else 1 - top["score"]
    )
    if score > self.threshold:
        raise PromptInjectionException("Prompt injection attack detected", score)
    return query