experimental[patch] Update prompt injection model (#13930)

- **Description:** Existing model used for Prompt Injection is quite
outdated but we fine-tuned and open-source a new model based on the same
model deberta-v3-base from Microsoft -
[laiyer/deberta-v3-base-prompt-injection](https://huggingface.co/laiyer/deberta-v3-base-prompt-injection).
It supports more up-to-date injections and less prone to
false-positives.
  - **Dependencies:** No
  - **Tag maintainer:** -
  - **Twitter handle:** @alex_yaremchuk

---------

Co-authored-by: Bagatur <baskaryan@gmail.com>
This commit is contained in:
Oleksandr Yaremchuk 2023-11-27 23:56:53 +01:00 committed by GitHub
parent e6ebde9688
commit c0277d06e8
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 50 additions and 15 deletions

View File

@ -8,7 +8,7 @@
"# Hugging Face prompt injection identification\n", "# Hugging Face prompt injection identification\n",
"\n", "\n",
"This notebook shows how to prevent prompt injection attacks using the text classification model from `HuggingFace`.\n", "This notebook shows how to prevent prompt injection attacks using the text classification model from `HuggingFace`.\n",
"It exploits the *deberta* model trained to identify prompt injections: https://huggingface.co/deepset/deberta-v3-base-injection" "By default it uses a *deberta* model trained to identify prompt injections. In this walkthrough we'll use https://huggingface.co/laiyer/deberta-v3-base-prompt-injection."
] ]
}, },
{ {
@ -21,19 +21,37 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 1, "execution_count": null,
"id": "aea25588-3c3f-4506-9094-221b3a0d519b", "id": "aea25588-3c3f-4506-9094-221b3a0d519b",
"metadata": {}, "metadata": {},
"outputs": [ "outputs": [
{ {
"data": { "data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "58ab3557623a495d8cc3c3e32a61938f",
"version_major": 2,
"version_minor": 0
},
"text/plain": [ "text/plain": [
"'hugging_face_injection_identifier'" "Downloading config.json: 0%| | 0.00/994 [00:00<?, ?B/s]"
] ]
}, },
"execution_count": 1,
"metadata": {}, "metadata": {},
"output_type": "execute_result" "output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "3bf062f02d304ab5a485a2a228b4cf41",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Downloading model.safetensors: 0%| | 0.00/738M [00:00<?, ?B/s]"
]
},
"metadata": {},
"output_type": "display_data"
} }
], ],
"source": [ "source": [
@ -41,7 +59,10 @@
" HuggingFaceInjectionIdentifier,\n", " HuggingFaceInjectionIdentifier,\n",
")\n", ")\n",
"\n", "\n",
"injection_identifier = HuggingFaceInjectionIdentifier()\n", "# Using https://huggingface.co/laiyer/deberta-v3-base-prompt-injection\n",
"injection_identifier = HuggingFaceInjectionIdentifier(\n",
" model=\"laiyer/deberta-v3-base-prompt-injection\"\n",
")\n",
"injection_identifier.name" "injection_identifier.name"
] ]
}, },
@ -299,9 +320,9 @@
], ],
"metadata": { "metadata": {
"kernelspec": { "kernelspec": {
"display_name": "Python 3 (ipykernel)", "display_name": "poetry-venv",
"language": "python", "language": "python",
"name": "python3" "name": "poetry-venv"
}, },
"language_info": { "language_info": {
"codemirror_mode": { "codemirror_mode": {
@ -313,7 +334,7 @@
"name": "python", "name": "python",
"nbconvert_exporter": "python", "nbconvert_exporter": "python",
"pygments_lexer": "ipython3", "pygments_lexer": "ipython3",
"version": "3.10.12" "version": "3.9.1"
} }
}, },
"nbformat": 4, "nbformat": 4,

View File

@ -1,16 +1,18 @@
"""Tool for the identification of prompt injection attacks.""" """Tool for the identification of prompt injection attacks."""
from __future__ import annotations from __future__ import annotations
from typing import TYPE_CHECKING from typing import TYPE_CHECKING, Any
from langchain.pydantic_v1 import Field from langchain.pydantic_v1 import Field, root_validator
from langchain.tools.base import BaseTool from langchain.tools.base import BaseTool
if TYPE_CHECKING: if TYPE_CHECKING:
from transformers import Pipeline from transformers import Pipeline
def _model_default_factory() -> Pipeline: def _model_default_factory(
model_name: str = "deepset/deberta-v3-base-injection"
) -> Pipeline:
try: try:
from transformers import pipeline from transformers import pipeline
except ImportError as e: except ImportError as e:
@ -18,11 +20,11 @@ def _model_default_factory() -> Pipeline:
"Cannot import transformers, please install with " "Cannot import transformers, please install with "
"`pip install transformers`." "`pip install transformers`."
) from e ) from e
return pipeline("text-classification", model="deepset/deberta-v3-base-injection") return pipeline("text-classification", model=model_name)
class HuggingFaceInjectionIdentifier(BaseTool): class HuggingFaceInjectionIdentifier(BaseTool):
"""Tool that uses deberta-v3-base-injection to detect prompt injection attacks.""" """Tool that uses HF model to detect prompt injection attacks."""
name: str = "hugging_face_injection_identifier" name: str = "hugging_face_injection_identifier"
description: str = ( description: str = (
@ -30,7 +32,19 @@ class HuggingFaceInjectionIdentifier(BaseTool):
"Useful for when you need to ensure that prompt is free of injection attacks. " "Useful for when you need to ensure that prompt is free of injection attacks. "
"Input should be any message from the user." "Input should be any message from the user."
) )
model: Pipeline = Field(default_factory=_model_default_factory) model: Any = Field(default_factory=_model_default_factory)
"""Model to use for prompt injection detection.
Can be specified as transformers Pipeline or string. String should correspond to the
model name of a text-classification transformers model. Defaults to
``deepset/deberta-v3-base-injection`` model.
"""
@root_validator(pre=True)
def validate_environment(cls, values: dict) -> dict:
if isinstance(values.get("model"), str):
values["model"] = _model_default_factory(model_name=values["model"])
return values
def _run(self, query: str) -> str: def _run(self, query: str) -> str:
"""Use the tool.""" """Use the tool."""