mirror of
https://github.com/hwchase17/langchain.git
synced 2025-09-13 21:47:12 +00:00
Improve prompt injection detection (#14842)
- **Description:** This is addition to [my previous PR](https://github.com/langchain-ai/langchain/pull/13930) with improvements to flexibility allowing different models and notebook to use ONNX runtime for faster speed. Since the last PR, [our model](https://huggingface.co/laiyer/deberta-v3-base-prompt-injection) got more than 660k downloads, and with the [public benchmark](https://huggingface.co/spaces/laiyer/prompt-injection-benchmark) showed much fewer false-positives than the previous one from deepset. Additionally, on the ONNX runtime, it can be running 3x faster on the CPU, which might be handy for builders using Langchain. **Issue:** N/A - **Dependencies:** N/A - **Tag maintainer:** N/A - **Twitter handle:** `@laiyer_ai`
This commit is contained in:
committed by
GitHub
parent
f8dccaa027
commit
d82a3828f2
@@ -1,7 +1,7 @@
|
||||
"""Tool for the identification of prompt injection attacks."""
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING, Any
|
||||
from typing import TYPE_CHECKING, Union
|
||||
|
||||
from langchain.pydantic_v1 import Field, root_validator
|
||||
from langchain.tools.base import BaseTool
|
||||
@@ -10,17 +10,39 @@ if TYPE_CHECKING:
|
||||
from transformers import Pipeline
|
||||
|
||||
|
||||
class PromptInjectionException(ValueError):
|
||||
def __init__(self, message="Prompt injection attack detected", score: float = 1.0):
|
||||
self.message = message
|
||||
self.score = score
|
||||
|
||||
super().__init__(self.message)
|
||||
|
||||
|
||||
def _model_default_factory(
|
||||
model_name: str = "deepset/deberta-v3-base-injection"
|
||||
model_name: str = "laiyer/deberta-v3-base-prompt-injection",
|
||||
) -> Pipeline:
|
||||
try:
|
||||
from transformers import pipeline
|
||||
from transformers import (
|
||||
AutoModelForSequenceClassification,
|
||||
AutoTokenizer,
|
||||
pipeline,
|
||||
)
|
||||
except ImportError as e:
|
||||
raise ImportError(
|
||||
"Cannot import transformers, please install with "
|
||||
"`pip install transformers`."
|
||||
) from e
|
||||
return pipeline("text-classification", model=model_name)
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_name)
|
||||
model = AutoModelForSequenceClassification.from_pretrained(model_name)
|
||||
|
||||
return pipeline(
|
||||
"text-classification",
|
||||
model=model,
|
||||
tokenizer=tokenizer,
|
||||
max_length=512, # default length of BERT models
|
||||
truncation=True, # otherwise it will fail on long prompts
|
||||
)
|
||||
|
||||
|
||||
class HuggingFaceInjectionIdentifier(BaseTool):
|
||||
@@ -32,13 +54,26 @@ class HuggingFaceInjectionIdentifier(BaseTool):
|
||||
"Useful for when you need to ensure that prompt is free of injection attacks. "
|
||||
"Input should be any message from the user."
|
||||
)
|
||||
model: Any = Field(default_factory=_model_default_factory)
|
||||
model: Union[Pipeline, str, None] = Field(default_factory=_model_default_factory)
|
||||
"""Model to use for prompt injection detection.
|
||||
|
||||
Can be specified as transformers Pipeline or string. String should correspond to the
|
||||
model name of a text-classification transformers model. Defaults to
|
||||
``deepset/deberta-v3-base-injection`` model.
|
||||
``laiyer/deberta-v3-base-prompt-injection`` model.
|
||||
"""
|
||||
threshold: float = Field(
|
||||
description="Threshold for prompt injection detection.", default=0.5
|
||||
)
|
||||
"""Threshold for prompt injection detection.
|
||||
|
||||
Defaults to 0.5."""
|
||||
injection_label: str = Field(
|
||||
description="Label of the injection for prompt injection detection.",
|
||||
default="INJECTION",
|
||||
)
|
||||
"""Label for prompt injection detection model.
|
||||
|
||||
Defaults to ``INJECTION``. Value depends on the model used."""
|
||||
|
||||
@root_validator(pre=True)
|
||||
def validate_environment(cls, values: dict) -> dict:
|
||||
@@ -49,7 +84,12 @@ class HuggingFaceInjectionIdentifier(BaseTool):
|
||||
def _run(self, query: str) -> str:
|
||||
"""Use the tool."""
|
||||
result = self.model(query)
|
||||
result = sorted(result, key=lambda x: x["score"], reverse=True)
|
||||
if result[0]["label"] == "INJECTION":
|
||||
raise ValueError("Prompt injection attack detected")
|
||||
score = (
|
||||
result[0]["score"]
|
||||
if result[0]["label"] == self.injection_label
|
||||
else 1 - result[0]["score"]
|
||||
)
|
||||
if score > self.threshold:
|
||||
raise PromptInjectionException("Prompt injection attack detected", score)
|
||||
|
||||
return query
|
||||
|
Reference in New Issue
Block a user