mirror of
https://github.com/hwchase17/langchain.git
synced 2025-06-18 21:09:00 +00:00
experimental[patch] Update prompt injection model (#13930)
- **Description:** The existing model used for prompt injection detection is quite outdated, so we fine-tuned and open-sourced a new model based on the same deberta-v3-base model from Microsoft - [laiyer/deberta-v3-base-prompt-injection](https://huggingface.co/laiyer/deberta-v3-base-prompt-injection). It supports more up-to-date injections and is less prone to false positives. - **Dependencies:** No - **Tag maintainer:** - - **Twitter handle:** @alex_yaremchuk --------- Co-authored-by: Bagatur <baskaryan@gmail.com>
This commit is contained in:
parent
e6ebde9688
commit
c0277d06e8
@ -8,7 +8,7 @@
|
|||||||
"# Hugging Face prompt injection identification\n",
|
"# Hugging Face prompt injection identification\n",
|
||||||
"\n",
|
"\n",
|
||||||
"This notebook shows how to prevent prompt injection attacks using the text classification model from `HuggingFace`.\n",
|
"This notebook shows how to prevent prompt injection attacks using the text classification model from `HuggingFace`.\n",
|
||||||
"It exploits the *deberta* model trained to identify prompt injections: https://huggingface.co/deepset/deberta-v3-base-injection"
|
"By default it uses a *deberta* model trained to identify prompt injections. In this walkthrough we'll use https://huggingface.co/laiyer/deberta-v3-base-prompt-injection."
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
@ -21,19 +21,37 @@
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 1,
|
"execution_count": null,
|
||||||
"id": "aea25588-3c3f-4506-9094-221b3a0d519b",
|
"id": "aea25588-3c3f-4506-9094-221b3a0d519b",
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [
|
"outputs": [
|
||||||
{
|
{
|
||||||
"data": {
|
"data": {
|
||||||
|
"application/vnd.jupyter.widget-view+json": {
|
||||||
|
"model_id": "58ab3557623a495d8cc3c3e32a61938f",
|
||||||
|
"version_major": 2,
|
||||||
|
"version_minor": 0
|
||||||
|
},
|
||||||
"text/plain": [
|
"text/plain": [
|
||||||
"'hugging_face_injection_identifier'"
|
"Downloading config.json: 0%| | 0.00/994 [00:00<?, ?B/s]"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
"execution_count": 1,
|
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"output_type": "execute_result"
|
"output_type": "display_data"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"data": {
|
||||||
|
"application/vnd.jupyter.widget-view+json": {
|
||||||
|
"model_id": "3bf062f02d304ab5a485a2a228b4cf41",
|
||||||
|
"version_major": 2,
|
||||||
|
"version_minor": 0
|
||||||
|
},
|
||||||
|
"text/plain": [
|
||||||
|
"Downloading model.safetensors: 0%| | 0.00/738M [00:00<?, ?B/s]"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
"metadata": {},
|
||||||
|
"output_type": "display_data"
|
||||||
}
|
}
|
||||||
],
|
],
|
||||||
"source": [
|
"source": [
|
||||||
@ -41,7 +59,10 @@
|
|||||||
" HuggingFaceInjectionIdentifier,\n",
|
" HuggingFaceInjectionIdentifier,\n",
|
||||||
")\n",
|
")\n",
|
||||||
"\n",
|
"\n",
|
||||||
"injection_identifier = HuggingFaceInjectionIdentifier()\n",
|
"# Using https://huggingface.co/laiyer/deberta-v3-base-prompt-injection\n",
|
||||||
|
"injection_identifier = HuggingFaceInjectionIdentifier(\n",
|
||||||
|
" model=\"laiyer/deberta-v3-base-prompt-injection\"\n",
|
||||||
|
")\n",
|
||||||
"injection_identifier.name"
|
"injection_identifier.name"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
@ -299,9 +320,9 @@
|
|||||||
],
|
],
|
||||||
"metadata": {
|
"metadata": {
|
||||||
"kernelspec": {
|
"kernelspec": {
|
||||||
"display_name": "Python 3 (ipykernel)",
|
"display_name": "poetry-venv",
|
||||||
"language": "python",
|
"language": "python",
|
||||||
"name": "python3"
|
"name": "poetry-venv"
|
||||||
},
|
},
|
||||||
"language_info": {
|
"language_info": {
|
||||||
"codemirror_mode": {
|
"codemirror_mode": {
|
||||||
@ -313,7 +334,7 @@
|
|||||||
"name": "python",
|
"name": "python",
|
||||||
"nbconvert_exporter": "python",
|
"nbconvert_exporter": "python",
|
||||||
"pygments_lexer": "ipython3",
|
"pygments_lexer": "ipython3",
|
||||||
"version": "3.10.12"
|
"version": "3.9.1"
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
"nbformat": 4,
|
"nbformat": 4,
|
||||||
|
@ -1,16 +1,18 @@
|
|||||||
"""Tool for the identification of prompt injection attacks."""
|
"""Tool for the identification of prompt injection attacks."""
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from typing import TYPE_CHECKING
|
from typing import TYPE_CHECKING, Any
|
||||||
|
|
||||||
from langchain.pydantic_v1 import Field
|
from langchain.pydantic_v1 import Field, root_validator
|
||||||
from langchain.tools.base import BaseTool
|
from langchain.tools.base import BaseTool
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from transformers import Pipeline
|
from transformers import Pipeline
|
||||||
|
|
||||||
|
|
||||||
def _model_default_factory() -> Pipeline:
|
def _model_default_factory(
|
||||||
|
model_name: str = "deepset/deberta-v3-base-injection"
|
||||||
|
) -> Pipeline:
|
||||||
try:
|
try:
|
||||||
from transformers import pipeline
|
from transformers import pipeline
|
||||||
except ImportError as e:
|
except ImportError as e:
|
||||||
@ -18,11 +20,11 @@ def _model_default_factory() -> Pipeline:
|
|||||||
"Cannot import transformers, please install with "
|
"Cannot import transformers, please install with "
|
||||||
"`pip install transformers`."
|
"`pip install transformers`."
|
||||||
) from e
|
) from e
|
||||||
return pipeline("text-classification", model="deepset/deberta-v3-base-injection")
|
return pipeline("text-classification", model=model_name)
|
||||||
|
|
||||||
|
|
||||||
class HuggingFaceInjectionIdentifier(BaseTool):
|
class HuggingFaceInjectionIdentifier(BaseTool):
|
||||||
"""Tool that uses deberta-v3-base-injection to detect prompt injection attacks."""
|
"""Tool that uses HF model to detect prompt injection attacks."""
|
||||||
|
|
||||||
name: str = "hugging_face_injection_identifier"
|
name: str = "hugging_face_injection_identifier"
|
||||||
description: str = (
|
description: str = (
|
||||||
@ -30,7 +32,19 @@ class HuggingFaceInjectionIdentifier(BaseTool):
|
|||||||
"Useful for when you need to ensure that prompt is free of injection attacks. "
|
"Useful for when you need to ensure that prompt is free of injection attacks. "
|
||||||
"Input should be any message from the user."
|
"Input should be any message from the user."
|
||||||
)
|
)
|
||||||
model: Pipeline = Field(default_factory=_model_default_factory)
|
model: Any = Field(default_factory=_model_default_factory)
|
||||||
|
"""Model to use for prompt injection detection.
|
||||||
|
|
||||||
|
Can be specified as transformers Pipeline or string. String should correspond to the
|
||||||
|
model name of a text-classification transformers model. Defaults to
|
||||||
|
``deepset/deberta-v3-base-injection`` model.
|
||||||
|
"""
|
||||||
|
|
||||||
|
@root_validator(pre=True)
|
||||||
|
def validate_environment(cls, values: dict) -> dict:
|
||||||
|
if isinstance(values.get("model"), str):
|
||||||
|
values["model"] = _model_default_factory(model_name=values["model"])
|
||||||
|
return values
|
||||||
|
|
||||||
def _run(self, query: str) -> str:
|
def _run(self, query: str) -> str:
|
||||||
"""Use the tool."""
|
"""Use the tool."""
|
||||||
|
Loading…
Reference in New Issue
Block a user